diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 4a94c6d82..47f08aa96 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -1,8 +1,11 @@ import gzip +import logging +import os import time import uuid from typing import Any, Dict, List, Optional +import zmq from fastapi import FastAPI, status from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -17,6 +20,8 @@ grab_exact_from_heterogeneous_queue, ) +logger = logging.getLogger(__name__) + # Constants MIN_ENV_WEIGHT = ( 0.01 # Minimum weight to prevent environments from being completely starved @@ -245,6 +250,19 @@ class Info(BaseModel): batch_size: int = -1 +def send_to_sidecar(payload: Dict[str, Any], port: int): + """Helper to send payload to ZMQ sidecar (PUSH).""" + try: + context = zmq.Context() + socket = context.socket(zmq.PUSH) + socket.connect(f"tcp://localhost:{port}") + socket.send_pyobj(payload) + socket.close() + context.term() + except Exception as e: + logger.error(f"Failed to send payload to sidecar: {e}") + + @app.post("/register") async def register(registration: Registration): # Initialize app state if not already done @@ -263,6 +281,38 @@ async def register(registration: Registration): app.state.envs = [] app.state.buffer = {} # Buffer for mixed-size groups per environment + # Init ZMQ config + zmq_port_str = os.getenv("ATROPOS_ZMQ_PORT", "5555") + zmq_port = int(zmq_port_str) + app.state.zmq_port = zmq_port + + if registration.wandb_project: + # Resume Logic: + # hash of the group name saved locally so that if server crashes + # or some other issue happens we can keep using the same run for the env instances + import hashlib + + run_id = hashlib.md5(registration.wandb_group.encode()).hexdigest() + + # Generate config for sidecar + wandb_config = { + "project": registration.wandb_project, + "group": registration.wandb_group, + "name": f"API_Aggregator_{registration.wandb_group}", + "id": run_id, + "resume": 
"allow", + "config": registration.model_dump(), + "reinit": True, + } + + # Send INIT command to sidecar + send_to_sidecar({"_type": "init", "config": wandb_config}, zmq_port) + + # Store run ID for /wandb_info + app.state.wandb_run_id = run_id + else: + app.state.wandb_run_id = None + # Initialize requesters list if not already done if not hasattr(app.state, "requesters"): app.state.requesters = [] @@ -273,23 +323,35 @@ async def register(registration: Registration): @app.post("/register-env") async def register_env_url(register_env: RegisterEnv): - # Check if trainer has started if not hasattr(app.state, "started") or not app.state.started: - return { - "status": "wait for trainer to start", - } + return {"status": "wait for trainer to start"} - # Initialize envs list if not already done if not hasattr(app.state, "envs"): app.state.envs = [] + if not hasattr(app.state, "env_leaders"): + app.state.env_leaders = {} + if not hasattr(app.state, "next_leader_port"): + app.state.next_leader_port = 5600 - # Get checkpoint directory safely checkpoint_dir = getattr(app.state, "checkpoint_dir", "") - real_name = ( - f"{register_env.desired_name}_" - f"{len([x for x in app.state.envs if x['desired_name'] == register_env.desired_name])}" + instance_index = len( + [x for x in app.state.envs if x["desired_name"] == register_env.desired_name] ) + real_name = f"{register_env.desired_name}_{instance_index}" registered_id = len(app.state.envs) + + is_leader = register_env.desired_name not in app.state.env_leaders + leader_receive_port = None + + if is_leader: + leader_receive_port = app.state.next_leader_port + app.state.next_leader_port += 1 + app.state.env_leaders[register_env.desired_name] = { + "instance": real_name, + "env_id": registered_id, + "receive_port": leader_receive_port, + } + app.state.envs.append( { "max_context_len": register_env.max_token_length, @@ -301,8 +363,21 @@ async def register_env_url(register_env: RegisterEnv): "connected": True, 
"min_batch_allocation": register_env.min_batch_allocation, "group_size": register_env.group_size, + "is_leader": is_leader, } ) + + if hasattr(app.state, "zmq_port"): + msg = { + "_type": "env_register", + "env_type": register_env.desired_name, + "instance": real_name, + "is_leader": is_leader, + } + if is_leader: + msg["leader_receive_port"] = leader_receive_port + send_to_sidecar(msg, app.state.zmq_port) + return { "status": "success", "env_id": registered_id, @@ -311,13 +386,33 @@ async def register_env_url(register_env: RegisterEnv): "starting_step": app.state.status_dict["step"], "checkpoint_interval": app.state.save_checkpoint_interval, "num_steps": app.state.num_steps, + "is_leader": is_leader, + "leader_receive_port": leader_receive_port, + "wandb_project": getattr(app.state, "project", None), + "wandb_group": getattr(app.state, "group", None), } @app.post("/disconnect-env") async def disconnect_env(disconnect_env: EnvIdentifier): try: - app.state.envs[disconnect_env.env_id]["connected"] = False + env = app.state.envs[disconnect_env.env_id] + env["connected"] = False + + if hasattr(app.state, "zmq_port"): + send_to_sidecar( + { + "_type": "env_disconnect", + "env_type": env["desired_name"], + "instance": env["real_name"], + "was_leader": env.get("is_leader", False), + }, + app.state.zmq_port, + ) + + if env.get("is_leader") and hasattr(app.state, "env_leaders"): + app.state.env_leaders.pop(env["desired_name"], None) + return {"status": "success"} except (AttributeError, IndexError) as e: return {"status": "failure", "error": str(e)} @@ -326,9 +421,19 @@ async def disconnect_env(disconnect_env: EnvIdentifier): @app.get("/wandb_info") async def wandb_info(): try: - return {"group": app.state.group, "project": app.state.project} + return { + "group": app.state.group, + "project": app.state.project, + "zmq_port": getattr(app.state, "zmq_port", None), + "wandb_run_id": getattr(app.state, "wandb_run_id", None), + } except AttributeError: - return {"group": None, 
"project": None} + return { + "group": None, + "project": None, + "zmq_port": None, + "wandb_run_id": None, + } @app.get("/info") @@ -409,7 +514,60 @@ async def get_latest_example(): @app.post("/scored_data") async def scored_data(scored_data: ScoredData): - return _process_scored_data(scored_data) + data_dict = { + "tokens": scored_data.tokens, + "masks": scored_data.masks, + "scores": scored_data.scores, + "advantages": scored_data.advantages, + "ref_logprobs": scored_data.ref_logprobs, + "messages": scored_data.messages, + "generation_params": scored_data.generation_params, + "inference_logprobs": scored_data.inference_logprobs, + "overrides": scored_data.overrides, + "group_overrides": scored_data.group_overrides, + "images": scored_data.images, + "env_id": scored_data.env_id, + } + # Check if this is a mixed-size group + env_id = scored_data.env_id + if env_id is not None and env_id < len(app.state.envs): + expected_group_size = app.state.envs[env_id].get("group_size", 1) + actual_group_size = len(scored_data.tokens) + + if actual_group_size != expected_group_size: + # Mixed size group - add to buffer + if env_id not in app.state.buffer: + app.state.buffer[env_id] = [] + + app.state.buffer[env_id].append(data_dict) + + # Try to find groups that sum to expected_group_size + indices = find_groups_summing_to_target( + app.state.buffer[env_id], expected_group_size + ) + + if indices: + # Add these groups to queue in order + groups_to_add = [] + for idx in sorted(indices, reverse=True): + groups_to_add.append(app.state.buffer[env_id].pop(idx)) + + # Add in FIFO order + for group in reversed(groups_to_add): + app.state.queue.append(group) + app.state.latest = group + + return { + "status": "buffered", + "buffer_size": sum( + len(g["tokens"]) for g in app.state.buffer.get(env_id, []) + ), + } + + # Normal path - correct size or no env info + app.state.queue.append(data_dict) + app.state.latest = data_dict + return {"status": "received"} @app.post("/scored_data_list") 
@@ -564,6 +722,10 @@ async def get_status_env(env: EnvIdentifier): @app.get("/reset_data") async def reset_data(): try: + # Send RESET to sidecar + if hasattr(app.state, "zmq_port"): + send_to_sidecar({"_type": "reset"}, app.state.zmq_port) + del app.state.queue app.state.group = None app.state.project = None diff --git a/atroposlib/api/sidecar.py b/atroposlib/api/sidecar.py new file mode 100644 index 000000000..7ba6422bf --- /dev/null +++ b/atroposlib/api/sidecar.py @@ -0,0 +1,259 @@ +import argparse +import logging +import time +from collections import defaultdict +from typing import Any, Dict, List, Optional, Set, Tuple + +import numpy as np +import zmq + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger("ZMQSidecar") + +AGGREGATION_TIMEOUT = 60.0 + + +class ZMQLogAggregator: + """ + Sidecar service that aggregates metrics from multiple environment instances + by (step, env_type) and routes aggregated data to the leader instance for + each env_type. The leader is responsible for logging to wandb. 
+ """ + + def __init__(self, port: int = 5555, context: Optional[zmq.Context] = None): + self.port = port + self.context = context or zmq.Context() + self.socket = self.context.socket(zmq.PULL) + self.running = False + + self.registered_envs: Dict[str, Set[str]] = defaultdict(set) + self.pending_metrics: Dict[ + Tuple[int, str], Dict[str, List[Tuple[str, Any]]] + ] = {} + self.env_reported: Dict[Tuple[int, str], Set[str]] = defaultdict(set) + self.pending_timestamps: Dict[Tuple[int, str], float] = {} + + # Leader info per env_type: {env_type: {"port": int, "socket": zmq.Socket}} + self.leaders: Dict[str, Dict[str, Any]] = {} + + def start(self): + if self.running: + return + + try: + self.socket.bind(f"tcp://*:{self.port}") + logger.info(f"ZMQLogAggregator listening on port {self.port}") + except zmq.ZMQError as e: + logger.error(f"Failed to bind ZMQ socket on port {self.port}: {e}") + raise + + self.running = True + self._loop() + + def stop(self): + self.running = False + try: + self.socket.close() + except Exception: + pass + for leader_info in self.leaders.values(): + try: + leader_info.get("socket", None).close() + except Exception: + pass + + def _connect_to_leader(self, env_type: str, port: int): + """Create a ZMQ PUSH socket to send aggregated data to the leader.""" + if env_type in self.leaders: + return + + try: + socket = self.context.socket(zmq.PUSH) + socket.setsockopt(zmq.SNDHWM, 10000) + socket.setsockopt(zmq.LINGER, 1000) + socket.connect(f"tcp://localhost:{port}") + self.leaders[env_type] = {"port": port, "socket": socket} + logger.info(f"Connected to leader for {env_type} on port {port}") + except Exception as e: + logger.error(f"Failed to connect to leader for {env_type}: {e}") + + def _handle_control_message(self, payload: Dict[str, Any]): + msg_type = payload.get("_type") + + if msg_type == "env_register": + env_type = payload.get("env_type") + instance = payload.get("instance") + is_leader = payload.get("is_leader", False) + 
leader_receive_port = payload.get("leader_receive_port") + + if env_type and instance: + self.registered_envs[env_type].add(instance) + logger.info(f"Registered {instance} for {env_type}") + + if is_leader and leader_receive_port: + self._connect_to_leader(env_type, leader_receive_port) + + elif msg_type == "env_disconnect": + env_type = payload.get("env_type") + instance = payload.get("instance") + was_leader = payload.get("was_leader", False) + + if env_type and instance: + self.registered_envs[env_type].discard(instance) + logger.info(f"Disconnected {instance} from {env_type}") + self._check_pending_after_disconnect(env_type) + + if was_leader and env_type in self.leaders: + try: + self.leaders[env_type]["socket"].close() + except Exception: + pass + del self.leaders[env_type] + logger.info(f"Removed leader connection for {env_type}") + + def _check_pending_after_disconnect(self, env_type: str): + keys_to_check = [k for k in self.pending_metrics if k[1] == env_type] + for key in keys_to_check: + if self._all_reported(key): + self._aggregate_and_send(key) + + def _all_reported(self, key: Tuple[int, str]) -> bool: + step, env_type = key + expected = self.registered_envs.get(env_type, set()) + reported = self.env_reported.get(key, set()) + return bool(expected) and reported >= expected + + def _handle_log_payload(self, payload: Dict[str, Any]): + step = payload.pop("_step", None) + env_type = payload.pop("_env_type", None) + instance = payload.pop("_instance", None) + + if env_type is None or instance is None: + logger.warning("Received log without env_type or instance, dropping") + return + + key = (step, env_type) + + if key not in self.pending_metrics: + self.pending_metrics[key] = defaultdict(list) + self.pending_timestamps[key] = time.time() + + for metric_name, value in payload.items(): + self.pending_metrics[key][metric_name].append((instance, value)) + + self.env_reported[key].add(instance) + + if self._all_reported(key): + self._aggregate_and_send(key) + + 
def _aggregate_and_send(self, key: Tuple[int, str]): + """Aggregate metrics and send to the leader for this env_type.""" + step, env_type = key + metrics = self.pending_metrics.pop(key, {}) + self.env_reported.pop(key, None) + self.pending_timestamps.pop(key, None) + + if not metrics: + return + + final_metrics = {} + + for metric_name, values in metrics.items(): + for instance, value in values: + final_metrics[f"{env_type}/instances/{instance}/{metric_name}"] = value + + numeric_values = [v for _, v in values if isinstance(v, (int, float))] + if numeric_values: + final_metrics[f"{env_type}/aggregated/{metric_name}_mean"] = np.mean( + numeric_values + ) + final_metrics[f"{env_type}/aggregated/{metric_name}_std"] = np.std( + numeric_values + ) + final_metrics[f"{env_type}/aggregated/{metric_name}_min"] = np.min( + numeric_values + ) + final_metrics[f"{env_type}/aggregated/{metric_name}_max"] = np.max( + numeric_values + ) + + if not final_metrics: + return + + final_metrics["_step"] = step + + leader_info = self.leaders.get(env_type) + if leader_info and leader_info.get("socket"): + try: + leader_info["socket"].send_pyobj(final_metrics, flags=zmq.NOBLOCK) + logger.debug( + f"Sent aggregated metrics for {env_type} step {step} to leader" + ) + except zmq.Again: + logger.warning( + f"Leader buffer full for {env_type}, dropping aggregated data" + ) + except Exception as e: + logger.error(f"Failed to send to leader for {env_type}: {e}") + else: + logger.warning( + f"No leader connected for {env_type}, dropping aggregated data" + ) + + def _check_timeouts(self): + now = time.time() + stale_keys = [ + k + for k, ts in self.pending_timestamps.items() + if now - ts > AGGREGATION_TIMEOUT + ] + for key in stale_keys: + step, env_type = key + logger.warning(f"Timeout for {env_type} step {step}, sending partial data") + self._aggregate_and_send(key) + + def _loop(self): + poller = zmq.Poller() + poller.register(self.socket, zmq.POLLIN) + logger.info("ZMQ Sidecar loop started") + 
+ last_timeout_check = time.time() + + while self.running: + try: + socks = dict(poller.poll(1000)) + + if self.socket in socks: + payload = self.socket.recv_pyobj() + + if isinstance(payload, dict) and "_type" in payload: + self._handle_control_message(payload) + else: + self._handle_log_payload(payload) + + if time.time() - last_timeout_check > 10: + self._check_timeouts() + last_timeout_check = time.time() + + except Exception as e: + logger.error(f"Error in ZMQLogAggregator loop: {e}") + + +def main(): + parser = argparse.ArgumentParser(description="Atropos ZMQ Logging Sidecar") + parser.add_argument("--port", type=int, default=5555, help="Port to listen on") + args = parser.parse_args() + + aggregator = ZMQLogAggregator(port=args.port) + try: + aggregator.start() + except KeyboardInterrupt: + logger.info("Stopping ZMQ Sidecar...") + aggregator.stop() + + +if __name__ == "__main__": + main() diff --git a/atroposlib/cli/inference_node_wandb_watcher.py b/atroposlib/cli/inference_node_wandb_watcher.py index e21e1c43f..2d19c2477 100644 --- a/atroposlib/cli/inference_node_wandb_watcher.py +++ b/atroposlib/cli/inference_node_wandb_watcher.py @@ -1,43 +1,66 @@ import argparse import time +from urllib.parse import urlparse import requests import wandb - -def update_wandb(health_statuses): - wandb.log(health_statuses) +from atroposlib.utils.logging_client import ZMQLogger def run(api_addr, tp, node_num): print(f"Starting up with {api_addr}, {tp}, {node_num}", flush=True) + zmq_logger = None + while True: try: data = requests.get(f"{api_addr}/wandb_info").json() - wandb_group = data["group"] - wandb_project = data["project"] + wandb_group = data.get("group") + wandb_project = data.get("project") + zmq_port = data.get("zmq_port") except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): wandb_project = None wandb_group = None + zmq_port = None print("Waiting for init...") if wandb_project is None: time.sleep(1) else: + if zmq_port: + try: + parsed = 
urlparse(api_addr) + host = parsed.hostname or "localhost" + zmq_addr = f"tcp://{host}:{zmq_port}" + zmq_logger = ZMQLogger(address=zmq_addr) + print(f"Connected to ZMQ Logger at {zmq_addr}") + break + except Exception as e: + print(f"Failed to connect ZMQ: {e}") + # does our existing/old wandb setup if zmq isn't open + wandb.init( project=wandb_project, group=wandb_group, name=f"inf_node_{node_num}" ) break + curr_step = 0 health_statuses = { f"server/server_health_{node_num}_{i}": 0.0 for i in range(8 // tp) } while True: - data = requests.get(f"{api_addr}/status").json() - step = data["current_step"] - if step > curr_step: - wandb.log(health_statuses, step=step) - curr_step = step + try: + data = requests.get(f"{api_addr}/status").json() + step = data["current_step"] + if step > curr_step: + if zmq_logger: + zmq_logger.log(health_statuses, step=step) + else: + wandb.log(health_statuses, step=step) + curr_step = step + except Exception as e: + print(f"Error fetching status: {e}") + time.sleep(60) # Check on each server for i in range(8 // tp): diff --git a/atroposlib/cli/run_api.py b/atroposlib/cli/run_api.py index 0e5e0a402..ca9b74809 100644 --- a/atroposlib/cli/run_api.py +++ b/atroposlib/cli/run_api.py @@ -1,12 +1,34 @@ """ -Run the Trajectory API server. +Run the Trajectory API server and the ZMQ Sidecar process. """ import argparse +import multiprocessing +import os +import signal +import sys import uvicorn +def run_sidecar(port: int): + """ + Run the ZMQ sidecar process. + """ + from atroposlib.api.sidecar import main as sidecar_main + + # Set the process title if possible + try: + import setproctitle + + setproctitle.setproctitle("atropos-zmq-sidecar") + except ImportError: + pass + + sys.argv = ["atropos-sidecar", "--port", str(port)] + sidecar_main() + + def main(): """ Run the API server. @@ -14,16 +36,42 @@ def main(): host: The host to run the API server on. port: The port to run the API server on. 
reload: Whether to reload the API server on code changes. + zmq_port: The port to run the ZMQ sidecar on. """ parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--zmq-port", type=int, default=5555) parser.add_argument("--reload", action="store_true") args = parser.parse_args() - uvicorn.run( - "atroposlib.api:app", host=args.host, port=args.port, reload=args.reload + + # Set the ZMQ port environment variable for the API server to discover + os.environ["ATROPOS_ZMQ_PORT"] = str(args.zmq_port) + + # Start the ZMQ sidecar process + sidecar_process = multiprocessing.Process( + target=run_sidecar, args=(args.zmq_port,), daemon=True ) + sidecar_process.start() + print(f"Started ZMQ sidecar on port {args.zmq_port} (pid={sidecar_process.pid})") + + try: + + uvicorn.run( + "atroposlib.api:app", host=args.host, port=args.port, reload=args.reload + ) + except KeyboardInterrupt: + print("Stopping API server...") + finally: + if sidecar_process.is_alive(): + print("Stopping ZMQ sidecar...") + sidecar_process.terminate() + sidecar_process.join(timeout=2) + if sidecar_process.is_alive(): + sidecar_process.kill() if __name__ == "__main__": + + signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit(0)) main() diff --git a/atroposlib/envs/README.md b/atroposlib/envs/README.md index 7b94d0aae..531dcf121 100644 --- a/atroposlib/envs/README.md +++ b/atroposlib/envs/README.md @@ -198,3 +198,28 @@ These class-level variables in `BaseEnv` can be overridden in your subclass to c * **CLI Integration**: Provides a `cli()` class method using `pydantic-cli` to easily create command-line interfaces for your environment (e.g., `python your_env_module.py serve --port 8001 ...`). See `get_cli_serve_config_cls` and `get_cli_process_config_cls`. 
By implementing the required methods and optionally overriding others, you can create diverse environments that leverage the distributed training infrastructure provided by the `Atropos` framework. + +--- + +## Weave tracing in environments + +Environments emit Weave traces to help you inspect rollout flow and LLM calls: + +- Enabled by default; disable via config or env: + - Config: under your OpenAI server settings, set `tracing_enabled: false` + - YAML example: + ```yaml + openai: + model_name: your-model + base_url: http://localhost:9000 + tracing_enabled: false + ``` + - CLI example: `--openai--tracing_enabled false` + - Env (hard disable): `WEAVE_DISABLED=true` +- Optional project override via `WEAVE_PROJECT=`. +- Traces include: + - Environment operations (group collection and send-to-API) + - LLM calls to OpenAI-compatible providers (including SGLang/TRL vLLM wrappers) + - Useful attributes (env name/id, model name, base URL, group/batch sizes) + +View your traces at `https://weave.wandb.ai` under your project. 
diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index fda844e60..c677fa30d 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -10,6 +10,7 @@ import uuid import warnings from abc import ABC, abstractmethod +from contextlib import nullcontext from datetime import datetime from enum import Enum from pathlib import Path @@ -27,6 +28,11 @@ from transformers import AutoTokenizer from typing_extensions import TypedDict +try: + import weave +except ImportError: + weave = None + from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE from atroposlib.envs.server_handling.openai_server import resolve_openai_configs from atroposlib.frontend.jsonl2html import generate_html @@ -38,9 +44,15 @@ merge_dicts, ) from atroposlib.utils.io import parse_http_response +from atroposlib.utils.logging_client import ( + ZMQLogger, + ZMQLogReceiver, + setup_weave_for_worker, +) from atroposlib.utils.metrics import get_std_min_max_avg from ..type_definitions import Item, Message +from .server_handling.server_baseline import weave_op from .server_handling.server_manager import ( APIServer, APIServerConfig, @@ -225,6 +237,9 @@ def __init__( self.wandb_prepend = None self.checkpoint_dir = "" self.checkpoint_interval = -1 + self.zmq_logger = None + self.is_leader = False + self.log_receiver = None if self.config.data_path_to_save_groups is not None: Path(self.config.data_path_to_save_groups).parent.mkdir( parents=True, exist_ok=True @@ -287,6 +302,7 @@ async def collect_trajectory( "Handle env single method must be implemented in subclass " ) + @weave_op async def collect_trajectories(self, item: Item) -> Tuple[ Union[ Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None] @@ -425,6 +441,7 @@ async def setup(self): async def setup_wandb(self): if self.config.use_wandb: + zmq_port = None # Setup wandb getting the group and project via the server while self.wandb_project is None: async with aiohttp.ClientSession() as 
session: @@ -434,27 +451,53 @@ async def setup_wandb(self): data = await parse_http_response(resp, logger) self.wandb_group = data["group"] self.wandb_project = data["project"] + zmq_port = data.get("zmq_port") + self.wandb_run_id = data.get("wandb_run_id") if self.wandb_project is None: await asyncio.sleep(1) continue - wandb_run_name = None - if self.config.wandb_name: - random_id = "".join(random.choices(string.ascii_lowercase, k=6)) - current_date = datetime.now().strftime("%Y-%m-%d") - wandb_run_name = ( - f"{self.config.wandb_name}-{current_date}-{random_id}" - ) + if zmq_port: + try: + from urllib.parse import urlparse - wandb.init( - name=wandb_run_name, - project=self.wandb_project, - group=self.wandb_group, - config=self.config.model_dump(), - ) + parsed_url = urlparse(self.config.rollout_server_url) + server_host = parsed_url.hostname or "localhost" + + zmq_addr = f"tcp://{server_host}:{zmq_port}" + self.zmq_logger = ZMQLogger(address=zmq_addr) + logger.info(f"Using ZMQ Logger connected to {zmq_addr}") + except Exception as e: + logger.error(f"Failed to init ZMQ Logger: {e}") + self.zmq_logger = None break + def _init_wandb_for_leader(self): + """Initialize wandb for a leader instance. 
Called after registration.""" + if not self.is_leader or not self.config.use_wandb: + return + + wandb_run_name = None + if self.config.wandb_name: + random_id = "".join(random.choices(string.ascii_lowercase, k=6)) + current_date = datetime.now().strftime("%Y-%m-%d") + wandb_run_name = f"{self.config.wandb_name}-{current_date}-{random_id}" + + if weave is not None: + try: + weave.init(self.wandb_project) + except Exception as e: + logger.warning(f"Failed to initialize Weave: {e}") + + wandb.init( + name=wandb_run_name, + project=self.wandb_project, + group=self.wandb_group, + config=self.config.model_dump(), + ) + logger.info(f"Leader {self.wandb_prepend} initialized wandb") + @retry( stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10), @@ -493,6 +536,29 @@ async def register_env(self): self.curr_step = data["starting_step"] self.checkpoint_dir = data["checkpoint_dir"] self.checkpoint_interval = data["checkpoint_interval"] + self.is_leader = data.get("is_leader", False) + leader_receive_port = data.get("leader_receive_port") + + if self.is_leader and leader_receive_port: + self.log_receiver = ZMQLogReceiver(port=leader_receive_port) + logger.info( + f"Leader {self.wandb_prepend} listening on port {leader_receive_port}" + ) + self._init_wandb_for_leader() + + if self.zmq_logger is not None and self.wandb_project: + setup_weave_for_worker( + self.wandb_project, + group_name=self.wandb_prepend, + run_id=getattr(self, "wandb_run_id", None), + ) + if weave is not None: + try: + weave.init(self.wandb_project) + except Exception as e: + logger.warning( + f"Failed to initialize Weave in register_env: {e}" + ) if self.config.total_steps == -1: self.config.total_steps = data["num_steps"] if self.config.total_steps == -1: @@ -501,7 +567,8 @@ async def register_env(self): f"Initialized env with id {self.env_id}: " f"curr_step: {self.curr_step}, " f"checkpoint_dir: {self.checkpoint_dir}, " - f"checkpoint_interval: {self.checkpoint_interval}" + 
f"checkpoint_interval: {self.checkpoint_interval}, " + f"is_leader: {self.is_leader}" ) if self.curr_step > 0: self.load_checkpoint() @@ -634,13 +701,35 @@ def wandb_log(self, wandb_metrics: Optional[Dict] = None): self.rollouts_for_wandb = [] self.completion_lengths = [] if self.config.use_wandb: - if self.wandb_prepend is not None: - wandb_metrics = { - f"{self.wandb_prepend}_{k}": v for k, v in wandb_metrics.items() - } - # add server metrics to wandb without prepend to collate them all wandb_metrics.update(server_wandb_metrics) - wandb.log(wandb_metrics, step=self.curr_step) + if self.zmq_logger is not None: + self.zmq_logger.log( + wandb_metrics, + step=self.curr_step, + env_type=self.config.wandb_name, + instance_name=self.wandb_prepend, + ) + else: + if self.wandb_prepend is not None: + wandb_metrics = { + f"{self.wandb_prepend}_{k}": v for k, v in wandb_metrics.items() + } + wandb.log(wandb_metrics, step=self.curr_step) + + def process_aggregated_logs(self): + """ + Process aggregated log data received from the sidecar (leader only). + Called periodically in the main loop. 
+ """ + if not self.is_leader or self.log_receiver is None: + return + + while True: + data = self.log_receiver.recv_nowait() + if data is None: + break + step = data.pop("_step", None) + wandb.log(data, step=step) async def evaluate_log( self, @@ -766,6 +855,7 @@ async def _send_scored_data_to_api(self, scored_data): if isinstance(scored_data, list) else f"{self.config.rollout_server_url}/scored_data" ) + async with aiohttp.ClientSession() as session: async with self._post_json_with_compression( session, @@ -803,6 +893,10 @@ def _post_json_with_compression( return session.post(url, data=body, headers=headers) + @retry( + stop=stop_after_attempt(3), + wait=wait_random_exponential(multiplier=1, max=10), + ) async def handle_send_to_api( self, scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]], @@ -898,6 +992,7 @@ async def handle_send_to_api( ) print(f"Failed to send {data_type_str} after retries: {e}") + @weave_op async def handle_env( self, item_uuid: str ) -> Optional[Union[ScoredDataGroup, List[ScoredDataGroup]]]: @@ -912,7 +1007,19 @@ async def handle_env( logger.debug(f"handle_env: Starting with item: {item}") # do a rollout with item try: - to_postprocess, to_backlog = await self.collect_trajectories(item) + if weave is not None and getattr(self, "env_id", None) is not None: + ctx = weave.attributes( + { + "env_id": self.env_id, + "env_name": self.wandb_prepend, + "env_type": self.config.wandb_name, + } + ) + else: + ctx = nullcontext() + + with ctx: + to_postprocess, to_backlog = await self.collect_trajectories(item) except Exception as e: logging.error(f"Error in collect_trajectories: {e}") to_postprocess = None @@ -1206,6 +1313,8 @@ async def env_manager(self): self.running_items.pop(item_uuid) # Do we want to retry? probably not... 
# self.backlog.append(item["item"]) + # Process aggregated logs if this is a leader + self.process_aggregated_logs() await asyncio.sleep(0.1) async def process_manager(self): @@ -1215,15 +1324,71 @@ async def process_manager(self): await self.setup() if self.config.use_wandb: - random_id = "".join(random.choices(string.ascii_lowercase, k=6)) - current_date = datetime.now().strftime("%Y-%m-%d") - wandb_run_name = f"{self.name}-{current_date}-{random_id}" - wandb.init( - project=self.wandb_project, - name=wandb_run_name, - group=self.wandb_group, - config=self.config.model_dump(), - ) + # check if zmq sidecar is open + zmq_port = None + try: + async with aiohttp.ClientSession() as session: + async with session.get( + f"{self.config.rollout_server_url}/wandb_info" + ) as resp: + if resp.status == 200: + info = await resp.json() + zmq_port = info.get("zmq_port") + # Update project/group from API truth + if info.get("project"): + self.wandb_project = info["project"] + if info.get("group"): + self.wandb_group = info["group"] + except Exception as e: + logger.warning(f"Failed to fetch wandb_info from API: {e}") + + # logging for zmq messages + if zmq_port: + try: + from urllib.parse import urlparse + + parsed_url = urlparse(self.config.rollout_server_url) + server_host = parsed_url.hostname or "localhost" + + zmq_addr = f"tcp://{server_host}:{zmq_port}" + self.zmq_logger = ZMQLogger(address=zmq_addr) + logger.info(f"Using ZMQ Logger connected to {zmq_addr}") + + if self.wandb_project: + setup_weave_for_worker( + self.wandb_project, + group_name=self.wandb_prepend, # Pass the unique env ID as the 'group' for Weave traces + ) + if weave is not None: + try: + weave.init(self.wandb_project) + except Exception as e: + logger.warning( + f"Failed to initialize Weave in process_manager: {e}" + ) + + except Exception as e: + logger.error(f"Failed to init ZMQ Logger: {e}") + self.zmq_logger = None + + # regular wandb logs per env if zmq isnt open for some reason + if 
self.zmq_logger is None: + random_id = "".join(random.choices(string.ascii_lowercase, k=6)) + current_date = datetime.now().strftime("%Y-%m-%d") + wandb_run_name = f"{self.name}-{current_date}-{random_id}" + + if weave is not None: + try: + weave.init(self.wandb_project) + except Exception as e: + logger.warning(f"Failed to initialize Weave: {e}") + + wandb.init( + project=self.wandb_project, + name=wandb_run_name, + group=self.wandb_group, + config=self.config.model_dump(), + ) # Initialize the processing self.curr_step = 0 diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py index 18dbf35d3..180f35fc9 100644 --- a/atroposlib/envs/server_handling/server_baseline.py +++ b/atroposlib/envs/server_handling/server_baseline.py @@ -1,8 +1,12 @@ import asyncio import collections +import functools +import inspect +import os import time from abc import ABC, abstractmethod from asyncio import exceptions +from contextlib import contextmanager from typing import Literal, Optional import numpy as np @@ -11,6 +15,108 @@ from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_random_exponential +try: + import weave as _weave +except Exception: + _weave = None +try: + import wandb as _wandb +except Exception: + _wandb = None + +WEAVE_ENABLED = _weave is not None and os.getenv( + "WEAVE_DISABLED", "false" +).lower() not in ("1", "true", "yes") +_weave_initialized = False + + +def _resolve_weave_project_name() -> str: + + if _wandb is not None: + try: + run = getattr(_wandb, "run", None) + project_from_run = ( + getattr(run, "project", None) if run is not None else None + ) + if project_from_run: + return str(project_from_run) + + except Exception: + pass + + project_from_env = os.getenv("WANDB_PROJECT") + if project_from_env: + return project_from_env + + project_from_weave_env = os.getenv("WEAVE_PROJECT") + if project_from_weave_env: + return project_from_weave_env + + return 
"atropos" + + +def ensure_weave_init() -> None: + global _weave_initialized + if WEAVE_ENABLED and not _weave_initialized: + project_name = _resolve_weave_project_name() + try: + _weave.init(project_name) + except Exception: + + pass + _weave_initialized = True + + +def weave_op(func): + if WEAVE_ENABLED: + wrapped = _weave.op(func) + if inspect.iscoroutinefunction(func): + + @functools.wraps(wrapped) + async def init_then_call(*args, **kwargs): + # Gate tracing via per-instance config if available + tracing_enabled = True + try: + if len(args) > 0 and getattr(args[0], "config", None) is not None: + tracing_enabled = getattr( + args[0].config, "tracing_enabled", True + ) + except Exception: + tracing_enabled = True + if not tracing_enabled: + return await func(*args, **kwargs) + ensure_weave_init() + return await wrapped(*args, **kwargs) + + else: + + @functools.wraps(wrapped) + def init_then_call(*args, **kwargs): + tracing_enabled = True + try: + if len(args) > 0 and getattr(args[0], "config", None) is not None: + tracing_enabled = getattr( + args[0].config, "tracing_enabled", True + ) + except Exception: + tracing_enabled = True + if not tracing_enabled: + return func(*args, **kwargs) + ensure_weave_init() + return wrapped(*args, **kwargs) + + return init_then_call + return func + + +@contextmanager +def weave_attributes(attrs: dict): + if WEAVE_ENABLED: + with _weave.attributes(attrs): + yield + else: + yield + class AsyncSemWithAdaptiveWeight(asyncio.Semaphore): def __init__(self, value: int): @@ -111,6 +217,10 @@ class ServerBaseline(BaseModel): server_type: Literal["openai", "trl", "sglang", "vllm"] = Field( default="openai", description="Type of server to use" ) + tracing_enabled: bool = Field( + default=True, + description="Enable Weave tracing for chat/completion ops (overridden by WEAVE_DISABLED).", + ) class APIServerConfig(ServerBaseline): @@ -264,6 +374,7 @@ async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion: @retry( 
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) ) + @weave_op async def chat_completion(self, **kwargs) -> ChatCompletion: """ Chat completion handler, waits for the server to be healthy and then calls the chat completion wrapper. @@ -285,15 +396,27 @@ async def chat_completion(self, **kwargs) -> ChatCompletion: split = kwargs.pop("split", "train") stat_dict = {} stat_dict["attempts"] = 0 - if split == "train": - ret_data = await self._chat_comp(stat_dict, **kwargs) - self.request_timings.append(stat_dict["end"] - stat_dict["start"]) - self.attempts_list.append(stat_dict["attempts"]) - else: - # Give separate eval workers, if desired, gotta go fast for those evals - ret_data = await self._chat_eval(stat_dict, **kwargs) - self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) - self.eval_attempts_list.append(stat_dict["attempts"]) + with weave_attributes( + { + "server_type": getattr(self.config, "server_type", None), + "endpoint": "chat_completion", + "model": self.config.model_name, + "base_url": getattr(self.config, "base_url", None), + "split": split, + "n": kwargs.get("n", 1), + "wandb_group": os.getenv("WANDB_GROUP", "unknown"), + "wandb_run_id": os.getenv("WANDB_RUN_ID", None), + } + ): + if split == "train": + ret_data = await self._chat_comp(stat_dict, **kwargs) + self.request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.attempts_list.append(stat_dict["attempts"]) + else: + + ret_data = await self._chat_eval(stat_dict, **kwargs) + self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.eval_attempts_list.append(stat_dict["attempts"]) return ret_data @retry( @@ -330,6 +453,7 @@ async def _comp_eval(self, stat_dict, **kwargs) -> Completion: stat_dict["end"] = time.time() return completions + @weave_op async def completion(self, **kwargs) -> Completion: """ Completion handler, waits for the server to be healthy and then calls the completion wrapper. 
@@ -352,20 +476,35 @@ async def completion(self, **kwargs) -> Completion: split = kwargs.pop("split", "train") stat_dict = {} stat_dict["attempts"] = 0 - if split == "train": - ret_data = await self._comp(stat_dict, **kwargs) - self.request_timings.append(stat_dict["end"] - stat_dict["start"]) - self.attempts_list.append(stat_dict["attempts"]) - else: - # Give separate eval workers, if desired, gotta go fast for those evals - ret_data = await self._comp_eval(stat_dict, **kwargs) - self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) - self.eval_attempts_list.append(stat_dict["attempts"]) + with weave_attributes( + { + "server_type": getattr(self.config, "server_type", None), + "endpoint": "completion", + "model": self.config.model_name, + "base_url": getattr(self.config, "base_url", None), + "split": split, + "n": kwargs.get("n", 1), + "wandb_group": os.getenv( + "WANDB_GROUP", "unknown" + ), # This is set during the weave setup in our base env + "wandb_run_id": os.getenv("WANDB_RUN_ID", None), + } + ): + if split == "train": + ret_data = await self._comp(stat_dict, **kwargs) + self.request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.attempts_list.append(stat_dict["attempts"]) + else: + + ret_data = await self._comp_eval(stat_dict, **kwargs) + self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.eval_attempts_list.append(stat_dict["attempts"]) return ret_data @retry( stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) ) + @weave_op async def _tokens_and_logprobs_comp( self, stat_dict, **kwargs ) -> tuple[list, list, list, list]: @@ -426,13 +565,27 @@ async def tokens_and_logprobs_completion( split = kwargs.pop("split", "train") stat_dict = {} stat_dict["attempts"] = 0 - if split == "train": - ret_data = await self._tokens_and_logprobs_comp(stat_dict, **kwargs) - self.request_timings.append(stat_dict["end"] - stat_dict["start"]) - self.attempts_list.append(stat_dict["attempts"]) 
- else: - # Give separate eval workers, if desired, gotta go fast for those evals - ret_data = await self._tokens_and_logprobs_comp_eval(stat_dict, **kwargs) - self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) - self.eval_attempts_list.append(stat_dict["attempts"]) + with weave_attributes( + { + "server_type": getattr(self.config, "server_type", None), + "endpoint": "tokens_and_logprobs_completion", + "model": self.config.model_name, + "base_url": getattr(self.config, "base_url", None), + "split": split, + "n": kwargs.get("n", 1), + "wandb_group": os.getenv("WANDB_GROUP", "unknown"), + "wandb_run_id": os.getenv("WANDB_RUN_ID", None), + } + ): + if split == "train": + ret_data = await self._tokens_and_logprobs_comp(stat_dict, **kwargs) + self.request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.attempts_list.append(stat_dict["attempts"]) + else: + + ret_data = await self._tokens_and_logprobs_comp_eval( + stat_dict, **kwargs + ) + self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.eval_attempts_list.append(stat_dict["attempts"]) return ret_data diff --git a/atroposlib/tests/test_sidecar_aggregation.py b/atroposlib/tests/test_sidecar_aggregation.py new file mode 100644 index 000000000..87d18bc27 --- /dev/null +++ b/atroposlib/tests/test_sidecar_aggregation.py @@ -0,0 +1,380 @@ +"""Tests for the ZMQ sidecar aggregation and leader routing.""" + +import time +from collections import defaultdict +from unittest.mock import MagicMock, patch + +import pytest +import zmq + + +class TestZMQLogReceiver: + """Tests for the ZMQLogReceiver class.""" + + def test_receiver_init_and_bind(self): + """Test that receiver initializes and binds to port.""" + from atroposlib.utils.logging_client import ZMQLogReceiver + + context = zmq.Context() + receiver = ZMQLogReceiver(port=5700, context=context) + assert receiver.port == 5700 + receiver.close() + context.term() + + def test_receiver_recv_nowait_empty(self): + """Test 
that recv_nowait returns None when no data.""" + from atroposlib.utils.logging_client import ZMQLogReceiver + + context = zmq.Context() + receiver = ZMQLogReceiver(port=5701, context=context) + result = receiver.recv_nowait() + assert result is None + receiver.close() + context.term() + + def test_receiver_recv_data(self): + """Test that receiver can receive data from a PUSH socket.""" + from atroposlib.utils.logging_client import ZMQLogReceiver + + context = zmq.Context() + receiver = ZMQLogReceiver(port=5702, context=context) + + sender = context.socket(zmq.PUSH) + sender.connect("tcp://localhost:5702") + time.sleep(0.1) + + test_data = {"metric": 1.0, "_step": 10} + sender.send_pyobj(test_data) + time.sleep(0.1) + + result = receiver.recv_nowait() + assert result == test_data + + sender.close() + receiver.close() + context.term() + + +class TestZMQLogAggregator: + """Tests for the ZMQLogAggregator class.""" + + def test_aggregator_env_registration(self): + """Test environment registration.""" + from atroposlib.api.sidecar import ZMQLogAggregator + + aggregator = ZMQLogAggregator(port=5710) + aggregator._handle_control_message( + { + "_type": "env_register", + "env_type": "math", + "instance": "math_0", + "is_leader": True, + "leader_receive_port": 5800, + } + ) + assert "math_0" in aggregator.registered_envs["math"] + assert "math" in aggregator.leaders + + def test_aggregator_env_disconnect(self): + """Test environment disconnection.""" + from atroposlib.api.sidecar import ZMQLogAggregator + + aggregator = ZMQLogAggregator(port=5711) + aggregator.registered_envs["math"].add("math_0") + aggregator.registered_envs["math"].add("math_1") + + aggregator._handle_control_message( + { + "_type": "env_disconnect", + "env_type": "math", + "instance": "math_0", + "was_leader": False, + } + ) + assert "math_0" not in aggregator.registered_envs["math"] + assert "math_1" in aggregator.registered_envs["math"] + + def test_aggregator_all_reported(self): + """Test _all_reported 
logic.""" + from atroposlib.api.sidecar import ZMQLogAggregator + + aggregator = ZMQLogAggregator(port=5712) + aggregator.registered_envs["math"] = {"math_0", "math_1", "math_2"} + + key = (10, "math") + aggregator.env_reported[key] = {"math_0", "math_1"} + assert not aggregator._all_reported(key) + + aggregator.env_reported[key].add("math_2") + assert aggregator._all_reported(key) + + def test_aggregator_handle_log_payload(self): + """Test handling log payloads.""" + from atroposlib.api.sidecar import ZMQLogAggregator + + aggregator = ZMQLogAggregator(port=5713) + aggregator.registered_envs["math"] = {"math_0", "math_1"} + + aggregator._handle_log_payload( + { + "_step": 10, + "_env_type": "math", + "_instance": "math_0", + "accuracy": 0.8, + } + ) + + key = (10, "math") + assert key in aggregator.pending_metrics + assert "math_0" in aggregator.env_reported[key] + + def test_aggregator_aggregation(self): + """Test metric aggregation.""" + from atroposlib.api.sidecar import ZMQLogAggregator + + aggregator = ZMQLogAggregator(port=5714) + aggregator.registered_envs["math"] = {"math_0", "math_1"} + + # Create a mock leader socket + mock_socket = MagicMock() + aggregator.leaders["math"] = {"port": 5800, "socket": mock_socket} + + # Send logs from both instances + aggregator._handle_log_payload( + { + "_step": 10, + "_env_type": "math", + "_instance": "math_0", + "accuracy": 0.8, + } + ) + aggregator._handle_log_payload( + { + "_step": 10, + "_env_type": "math", + "_instance": "math_1", + "accuracy": 0.9, + } + ) + + # Verify aggregation was triggered (socket.send_pyobj was called) + mock_socket.send_pyobj.assert_called_once() + call_args = mock_socket.send_pyobj.call_args + sent_data = call_args[0][0] + + assert sent_data["_step"] == 10 + assert "math/instances/math_0/accuracy" in sent_data + assert "math/instances/math_1/accuracy" in sent_data + assert sent_data["math/aggregated/accuracy_mean"] == pytest.approx(0.85) + + def test_aggregator_timeout(self): + """Test 
timeout handling for slow instances.""" + from atroposlib.api.sidecar import AGGREGATION_TIMEOUT, ZMQLogAggregator + + aggregator = ZMQLogAggregator(port=5715) + aggregator.registered_envs["math"] = {"math_0", "math_1"} + + # Create a mock leader socket + mock_socket = MagicMock() + aggregator.leaders["math"] = {"port": 5800, "socket": mock_socket} + + # Only one instance reports + aggregator._handle_log_payload( + { + "_step": 10, + "_env_type": "math", + "_instance": "math_0", + "accuracy": 0.8, + } + ) + + key = (10, "math") + aggregator.pending_timestamps[key] = time.time() - AGGREGATION_TIMEOUT - 1 + + aggregator._check_timeouts() + + # Verify partial data was sent + mock_socket.send_pyobj.assert_called_once() + + def test_aggregator_multiple_env_types(self): + """Test aggregation with multiple environment types.""" + from atroposlib.api.sidecar import ZMQLogAggregator + + aggregator = ZMQLogAggregator(port=5716) + aggregator.registered_envs["math"] = {"math_0"} + aggregator.registered_envs["crossword"] = {"crossword_0"} + + mock_math_socket = MagicMock() + mock_crossword_socket = MagicMock() + aggregator.leaders["math"] = {"port": 5800, "socket": mock_math_socket} + aggregator.leaders["crossword"] = { + "port": 5801, + "socket": mock_crossword_socket, + } + + aggregator._handle_log_payload( + { + "_step": 10, + "_env_type": "math", + "_instance": "math_0", + "accuracy": 0.8, + } + ) + aggregator._handle_log_payload( + { + "_step": 10, + "_env_type": "crossword", + "_instance": "crossword_0", + "accuracy": 0.9, + } + ) + + # Both should be sent to their respective leaders + mock_math_socket.send_pyobj.assert_called_once() + mock_crossword_socket.send_pyobj.assert_called_once() + + math_data = mock_math_socket.send_pyobj.call_args[0][0] + crossword_data = mock_crossword_socket.send_pyobj.call_args[0][0] + + assert "math/instances/math_0/accuracy" in math_data + assert "crossword/instances/crossword_0/accuracy" in crossword_data + + +class TestZMQLogger: + 
"""Tests for the ZMQLogger class.""" + + def test_logger_sends_data(self): + """Test that ZMQLogger sends data with correct metadata.""" + from atroposlib.utils.logging_client import ZMQLogger + + context = zmq.Context() + receiver = context.socket(zmq.PULL) + receiver.bind("tcp://*:5720") + + logger = ZMQLogger(address="tcp://localhost:5720", context=context) + time.sleep(0.1) + + logger.log( + {"accuracy": 0.8}, + step=10, + env_type="math", + instance_name="math_0", + ) + time.sleep(0.1) + + data = receiver.recv_pyobj(flags=zmq.NOBLOCK) + assert data["_step"] == 10 + assert data["_env_type"] == "math" + assert data["_instance"] == "math_0" + assert data["accuracy"] == 0.8 + + logger.close() + receiver.close() + context.term() + + +class TestLeaderElection: + """Tests for leader election in server.py.""" + + @pytest.fixture + def app_state(self): + """Create a mock app state.""" + + class MockAppState: + started = True + envs = [] + env_leaders = {} + next_leader_port = 5600 + status_dict = {"step": 0} + save_checkpoint_interval = 100 + num_steps = 1000 + project = "test-project" + group = "test-group" + checkpoint_dir = "/tmp/checkpoints" + + return MockAppState() + + def test_first_instance_becomes_leader(self, app_state): + """Test that first instance of an env_type becomes leader.""" + # Simulate the logic from register_env endpoint + desired_name = "math" + instance_index = len( + [x for x in app_state.envs if x["desired_name"] == desired_name] + ) + real_name = f"{desired_name}_{instance_index}" + registered_id = len(app_state.envs) + + is_leader = desired_name not in app_state.env_leaders + leader_receive_port = None + + if is_leader: + leader_receive_port = app_state.next_leader_port + app_state.next_leader_port += 1 + app_state.env_leaders[desired_name] = { + "instance": real_name, + "env_id": registered_id, + "receive_port": leader_receive_port, + } + + assert is_leader is True + assert leader_receive_port == 5600 + assert "math" in 
app_state.env_leaders + + def test_second_instance_not_leader(self, app_state): + """Test that second instance is not a leader.""" + # First instance + app_state.env_leaders["math"] = { + "instance": "math_0", + "env_id": 0, + "receive_port": 5600, + } + app_state.envs.append({"desired_name": "math", "real_name": "math_0"}) + + # Second instance + desired_name = "math" + instance_index = len( + [x for x in app_state.envs if x["desired_name"] == desired_name] + ) + real_name = f"{desired_name}_{instance_index}" + + is_leader = desired_name not in app_state.env_leaders + leader_receive_port = None + + assert is_leader is False + assert leader_receive_port is None + + def test_different_env_types_have_own_leaders(self, app_state): + """Test that different env_types each get their own leader.""" + # math leader + app_state.env_leaders["math"] = { + "instance": "math_0", + "env_id": 0, + "receive_port": 5600, + } + app_state.envs.append({"desired_name": "math", "real_name": "math_0"}) + app_state.next_leader_port = 5601 + + # crossword - should get its own leader + desired_name = "crossword" + instance_index = len( + [x for x in app_state.envs if x["desired_name"] == desired_name] + ) + real_name = f"{desired_name}_{instance_index}" + registered_id = len(app_state.envs) + + is_leader = desired_name not in app_state.env_leaders + leader_receive_port = None + + if is_leader: + leader_receive_port = app_state.next_leader_port + app_state.next_leader_port += 1 + app_state.env_leaders[desired_name] = { + "instance": real_name, + "env_id": registered_id, + "receive_port": leader_receive_port, + } + + assert is_leader is True + assert leader_receive_port == 5601 + assert "math" in app_state.env_leaders + assert "crossword" in app_state.env_leaders diff --git a/atroposlib/utils/logging_client.py b/atroposlib/utils/logging_client.py new file mode 100644 index 000000000..633de8212 --- /dev/null +++ b/atroposlib/utils/logging_client.py @@ -0,0 +1,121 @@ +import logging +import os 
+from typing import Any, Dict, Optional + +import zmq + +logger = logging.getLogger(__name__) + + +class ZMQLogger: + """ + A client for the ZMQLogAggregator. Replaces local wandb.log calls + by pushing data to the central API server. + """ + + def __init__(self, address: str, context: Optional[zmq.Context] = None): + """ + Args: + address: Full ZMQ address ("tcp://1.2.3.4:5555") + context: Optional existing ZMQ context + """ + self.context = context or zmq.Context() + self.socket = self.context.socket(zmq.PUSH) + self.socket.setsockopt(zmq.SNDHWM, 10000) + self.socket.setsockopt(zmq.LINGER, 1000) + + logger.info(f"Connecting ZMQLogger to {address}") + self.socket.connect(address) + + def log( + self, + data: Dict[str, Any], + step: Optional[int] = None, + env_type: Optional[str] = None, + instance_name: Optional[str] = None, + commit: Optional[bool] = None, + ): + """ + Send log data to the central server. + + Args: + data: Dictionary of metrics to log + step: Optional step number + env_type: Environment type / name for aggregation (math) + instance_name: Instance number (math_1) + commit: Optional commit flag (wandb.log compatibility) + """ + if step is not None: + data["_step"] = step + if env_type is not None: + data["_env_type"] = env_type + if instance_name is not None: + data["_instance"] = instance_name + + try: + self.socket.send_pyobj(data, flags=zmq.NOBLOCK) + except zmq.Again: + logger.warning("ZMQLogger buffer full, dropping log packet") + except Exception as e: + logger.error(f"Failed to send log data: {e}") + + def close(self): + self.socket.close() + + +class ZMQLogReceiver: + """ + A receiver for aggregated log data. Used by leader instances to receive + aggregated metrics from the sidecar and log them to wandb. 
+ """ + + def __init__(self, port: int, context: Optional[zmq.Context] = None): + """ + Args: + port: Port to bind to for receiving aggregated data + context: Optional existing ZMQ context + """ + self.port = port + self.context = context or zmq.Context() + self.socket = self.context.socket(zmq.PULL) + self.socket.setsockopt(zmq.RCVHWM, 10000) + self.socket.bind(f"tcp://*:{port}") + self.running = False + logger.info(f"ZMQLogReceiver bound to port {port}") + + def recv_nowait(self) -> Optional[Dict[str, Any]]: + """ + Non-blocking receive of aggregated data. + + Returns: + Dictionary of aggregated metrics if available, None otherwise + """ + try: + return self.socket.recv_pyobj(flags=zmq.NOBLOCK) + except zmq.Again: + return None + except Exception as e: + logger.error(f"Failed to receive log data: {e}") + return None + + def close(self): + self.socket.close() + + +def setup_weave_for_worker( + project_name: str, group_name: Optional[str] = None, run_id: Optional[str] = None +): + """ + Configure environment variables so Weave uses the correct project, + even if wandb.init() is not called locally. + """ + if project_name: + os.environ["WEAVE_PROJECT"] = project_name + os.environ["WANDB_PROJECT"] = project_name + + if group_name: + os.environ["WANDB_GROUP"] = group_name + + if run_id: + # Weave often respects WANDB_RUN_ID to associate calls with a specific run + os.environ["WANDB_RUN_ID"] = run_id diff --git a/pyproject.toml b/pyproject.toml index dd3841eeb..bc460b174 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "markdown", "numpy", "wandb", + "weave", "gymnasium", "math-verify==0.7.0", "jinja2", @@ -26,6 +27,7 @@ dependencies = [ "jsonlines", "pydantic-cli", "hf_transfer", + "pyzmq", ] [project.scripts]