def _iter_rollout_pairs(entry: Any):
    """Yield ``(text, score)`` pairs from one rollout entry.

    An entry is either a single pair or a group (list) of pairs. After a JSON
    round trip tuples arrive as lists, so a flat ``["text", score]`` pair must
    be told apart from a group by looking at its first element.
    """
    if (
        isinstance(entry, (list, tuple))
        and len(entry) == 2
        and isinstance(entry[0], str)
    ):
        yield entry[0], entry[1]
        return
    for text, score in entry:
        yield text, score


async def _log_metrics(message: Dict[str, Any]) -> None:
    """Forward one ``metrics`` message to the central wandb run.

    Does nothing unless ``app.state.wandb_enabled`` is truthy and
    ``app.state.wandb_run`` is set.

    Args:
        message: Decoded bus message. Reads the optional keys
            ``wandb_prepend``, ``metrics``, ``server_metrics``, ``rollouts``
            and ``step``.
    """
    if not getattr(app.state, "wandb_enabled", False):
        return
    if getattr(app.state, "wandb_run", None) is None:
        return

    wandb_prepend = message.get("wandb_prepend")
    metrics = message.get("metrics") or {}
    server_metrics = message.get("server_metrics") or {}
    rollouts = message.get("rollouts") or []
    step = message.get("step")

    if wandb_prepend:
        metrics_to_log: Dict[str, Any] = {
            f"{wandb_prepend}_{k}": v for k, v in metrics.items()
        }
    else:
        metrics_to_log = dict(metrics)

    # Server metrics are logged without the prefix so they collate across envs.
    metrics_to_log.update(server_metrics)

    if rollouts:
        table = wandb.Table(columns=["text", "score"])
        for entry in rollouts:
            for text, score in _iter_rollout_pairs(entry):
                table.add_data(text, score)
        table_key = "train/rollouts"
        if wandb_prepend:
            table_key = f"{wandb_prepend}_{table_key}"
        metrics_to_log[table_key] = table

    # wandb.log is a blocking call: run it in a thread, one call at a time.
    async with app.state.wandb_lock:  # type: ignore[attr-defined]
        await asyncio.to_thread(wandb.log, metrics_to_log, step=step)


async def _message_bus_worker() -> None:
    """Receive messages from the PULL socket and dispatch registered ones.

    Messages without a token that is present in
    ``app.state.message_bus_tokens`` are dropped. A malformed message or a
    failure while logging is reported and skipped; it never stops the worker.
    """
    import logging

    log = logging.getLogger(__name__)

    socket = getattr(app.state, "message_bus_socket", None)
    if socket is None:
        return

    while True:
        try:
            message = await socket.recv_json()
        except asyncio.CancelledError:
            break
        except Exception as exc:
            # A closed socket will fail forever: stop instead of spinning.
            if getattr(socket, "closed", False):
                break
            log.warning("Message bus receive failed: %s", exc)
            # Yield to the event loop so repeated failures cannot starve it.
            await asyncio.sleep(0.05)
            continue

        if not isinstance(message, dict):
            continue

        token = message.get("token")
        if not isinstance(token, str):
            continue
        known_tokens = getattr(app.state, "message_bus_tokens", None) or {}
        if known_tokens.get(token) is None:
            continue

        if message.get("type") == "metrics":
            try:
                await _log_metrics(message)
            except Exception:
                # CancelledError is a BaseException and still propagates.
                log.exception("Failed to log metrics received on the message bus")


@app.on_event("startup")
async def startup_event() -> None:
    """Initialise wandb state and, when enabled, bind the message bus."""
    # Set before the worker task is created so the worker never sees them missing.
    app.state.wandb_run = None
    app.state.wandb_enabled = False
    app.state.wandb_lock = asyncio.Lock()

    if not MESSAGE_BUS_ENABLED:
        return

    context = zmq.asyncio.Context.instance()
    socket = context.socket(zmq.PULL)
    try:
        socket.setsockopt(zmq.LINGER, 0)
        socket.setsockopt(zmq.RCVHWM, 0)
        socket.bind(MESSAGE_BUS_ENDPOINT)
    except Exception:
        # Do not leak the socket when the endpoint cannot be bound.
        socket.close(linger=0)
        raise

    app.state.message_bus_context = context
    app.state.message_bus_socket = socket
    app.state.message_bus_tokens = {}
    app.state.message_bus_endpoint = MESSAGE_BUS_ENDPOINT
    app.state.message_bus_task = asyncio.create_task(_message_bus_worker())


@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Stop the worker, close the socket, finish wandb and end the context."""
    message_bus_task = getattr(app.state, "message_bus_task", None)
    if message_bus_task is not None:
        message_bus_task.cancel()
        # CancelledError derives from BaseException (Python 3.8+), so
        # suppress(Exception) alone would let it escape and abort shutdown.
        with suppress(asyncio.CancelledError, Exception):
            await message_bus_task
        app.state.message_bus_task = None

    socket = getattr(app.state, "message_bus_socket", None)
    if socket is not None:
        socket.close(linger=0)
        app.state.message_bus_socket = None

    if getattr(app.state, "wandb_run", None) is not None:
        wandb.finish()
        app.state.wandb_run = None

    context = getattr(app.state, "message_bus_context", None)
    if context is not None:
        # NOTE(review): this is the shared Context.instance(); term() may wait
        # for sockets opened elsewhere in this process - confirm this is wanted.
        context.term()
        app.state.message_bus_context = None
@app.post("/disconnect-env")
async def disconnect_env(disconnect_env: EnvIdentifier):
    """Mark an environment as disconnected and drop its message-bus token.

    Returns ``{"status": "success"}``, or a ``failure`` status carrying the
    error text when the lookup raises ``AttributeError`` or ``IndexError``.
    """
    try:
        env_record = app.state.envs[disconnect_env.env_id]
        env_record["connected"] = False
        token = env_record.get("message_token")
        if token:
            token_table = getattr(app.state, "message_bus_tokens", None)
            if token_table is not None:
                # Messages carrying this token are ignored from now on.
                token_table.pop(token, None)
    except (AttributeError, IndexError) as e:
        return {"status": "failure", "error": str(e)}
    return {"status": "success"}
self.config.use_wandb: - # Setup wandb getting the group and project via the server - while self.wandb_project is None: - async with aiohttp.ClientSession() as session: - async with session.get( - f"{self.config.rollout_server_url}/wandb_info" - ) as resp: - data = await parse_http_response(resp, logger) - self.wandb_group = data["group"] - self.wandb_project = data["project"] - - if self.wandb_project is None: - await asyncio.sleep(1) - continue - - wandb_run_name = None - if self.config.wandb_name: - random_id = "".join(random.choices(string.ascii_lowercase, k=6)) - current_date = datetime.now().strftime("%Y-%m-%d") - wandb_run_name = ( - f"{self.config.wandb_name}-{current_date}-{random_id}" - ) + if not self.config.use_wandb: + return - wandb.init( - name=wandb_run_name, - project=self.wandb_project, - group=self.wandb_group, - config=self.config.model_dump(), - ) - break + # Setup wandb getting the group and project via the server + while self.wandb_project is None: + async with aiohttp.ClientSession() as session: + async with session.get( + f"{self.config.rollout_server_url}/wandb_info" + ) as resp: + data = await parse_http_response(resp, logger) + self.wandb_group = data["group"] + self.wandb_project = data["project"] + + if self.wandb_project is None: + await asyncio.sleep(1) + continue + + # When a central logger is present we avoid creating per-env wandb runs. 
+ if self.message_bus_client is not None: + return + + wandb_run_name = None + if self.config.wandb_name: + random_id = "".join(random.choices(string.ascii_lowercase, k=6)) + current_date = datetime.now().strftime("%Y-%m-%d") + wandb_run_name = f"{self.config.wandb_name}-{current_date}-{random_id}" + + wandb.init( + name=wandb_run_name, + project=self.wandb_project, + group=self.wandb_group, + config=self.config.model_dump(), + ) @retry( stop=stop_after_attempt(3), @@ -496,6 +504,14 @@ async def register_env(self): self.config.total_steps = data["num_steps"] if self.config.total_steps == -1: raise ValueError("Total steps not set in config or server!") + self.message_bus_details = data.get("message_bus") + if self.message_bus_details: + self.message_bus_env_name = self.message_bus_details.get( + "env_name", self.config.wandb_name + ) + self.message_bus_wandb_prepend = self.message_bus_details.get( + "wandb_prepend" + ) print( f"Initialized env with id {self.env_id}: " f"curr_step: {self.curr_step}, " @@ -506,6 +522,33 @@ async def register_env(self): self.load_checkpoint() break + async def setup_message_bus(self): + if self.message_bus_client is not None: + return + if not self.message_bus_details: + return + + endpoint = self.message_bus_details.get("endpoint") + token = self.message_bus_details.get("token") + if not endpoint or not token: + logger.warning( + "Message bus details missing endpoint or token, skipping connection." 
+ ) + return + try: + self.message_bus_client = MessageBusClient(endpoint=endpoint, token=token) + except Exception as exc: + logger.warning(f"Failed to initialise message bus client: {exc}") + self.message_bus_client = None + + async def close_message_bus(self): + if self.message_bus_client is None: + return + try: + await self.message_bus_client.close() + finally: + self.message_bus_client = None + async def get_server_info(self): """ Get the server info @@ -567,6 +610,18 @@ def perf_stats(self, metrics_dict): self.workers_added_list = list() return metrics_dict + def _sanitize_for_json(self, value: Any) -> Any: + """Convert numpy/native structures into JSON-serialisable primitives.""" + if isinstance(value, np.generic): + return value.item() + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, dict): + return {k: self._sanitize_for_json(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [self._sanitize_for_json(v) for v in value] + return value + async def create_rollout_table(self, wandb_metrics): if len(self.rollouts_for_wandb) > 0: table = wandb.Table(columns=["text", "score"]) @@ -610,8 +665,11 @@ def wandb_log(self, wandb_metrics: Optional[Dict] = None): """ if wandb_metrics is None: wandb_metrics = dict() + server_wandb_metrics: Dict[str, Any] = {} for i, server in enumerate(self.server.servers): - server_wandb_metrics = await server.wandb_metrics({}, f"server_{i}") + server_wandb_metrics = await server.wandb_metrics( + server_wandb_metrics, f"server_{i}" + ) if len(self.completion_lengths) > 0: wandb_metrics["train/completion_lengths"] = sum( self.completion_lengths @@ -628,18 +686,48 @@ def wandb_log(self, wandb_metrics: Optional[Dict] = None): wandb_metrics["train/completion_lengths_p95"] = ( np.array(self.completion_lengths) > (0.95 * self.max_token_len) ).mean() - wandb_metrics = await self.create_rollout_table(wandb_metrics) + + use_message_bus = self.message_bus_client is not None and 
self.config.use_wandb + if not use_message_bus: + wandb_metrics = await self.create_rollout_table(wandb_metrics) wandb_metrics = self.perf_stats(wandb_metrics) + + rollout_snapshot: List[List[Tuple[str, float]]] = [] + if use_message_bus: + rollout_snapshot = [list(group) for group in self.rollouts_for_wandb] self.rollouts_for_wandb = [] self.completion_lengths = [] - if self.config.use_wandb: - if self.wandb_prepend is not None: - wandb_metrics = { - f"{self.wandb_prepend}_{k}": v for k, v in wandb_metrics.items() - } - # add server metrics to wandb without prepend to collate them all - wandb_metrics.update(server_wandb_metrics) - wandb.log(wandb_metrics, step=self.curr_step) + if not self.config.use_wandb: + return + + if use_message_bus and self.message_bus_client is not None: + payload = { + "type": "metrics", + "env_id": self.env_id, + "env_name": self.message_bus_env_name or self.config.wandb_name, + "wandb_prepend": self.message_bus_wandb_prepend or self.wandb_prepend, + "metrics": { + k: self._sanitize_for_json(v) for k, v in wandb_metrics.items() + }, + "server_metrics": { + k: self._sanitize_for_json(v) + for k, v in server_wandb_metrics.items() + }, + "rollouts": self._sanitize_for_json(rollout_snapshot), + "step": self.curr_step, + } + try: + await self.message_bus_client.send_json(payload) + except Exception as exc: + logger.warning(f"Failed to send metrics to message bus: {exc}") + return + + if self.wandb_prepend is not None: + wandb_metrics = { + f"{self.wandb_prepend}_{k}": v for k, v in wandb_metrics.items() + } + wandb_metrics.update(server_wandb_metrics) + wandb.log(wandb_metrics, step=self.curr_step) async def evaluate_log( self, @@ -1106,9 +1194,10 @@ async def env_manager(self): Rollout manager """ await self.setup() - await self.setup_wandb() await self.register_env() await self.get_server_info() + await self.setup_message_bus() + await self.setup_wandb() # Wait for other instances to get setup :) await asyncio.sleep(5) while True: @@ 
-1185,6 +1274,7 @@ async def env_manager(self): # Do we want to retry? probably not... # self.backlog.append(item["item"]) await asyncio.sleep(0.1) + await self.close_message_bus() async def process_manager(self): """ diff --git a/atroposlib/tests/test_api_messages_handling.py b/atroposlib/tests/test_api_messages_handling.py index 0f7ac9227..cf026627e 100644 --- a/atroposlib/tests/test_api_messages_handling.py +++ b/atroposlib/tests/test_api_messages_handling.py @@ -4,29 +4,45 @@ import os import signal +import socket import subprocess import time +from subprocess import TimeoutExpired import pytest import requests -def wait_for_api_server(max_wait=10): +def wait_for_api_server(max_wait: float = 30.0, interval: float = 0.2) -> bool: """Wait for API server to be ready.""" - for _ in range(max_wait): + attempts = max(1, int(max_wait / interval)) + for attempt in range(attempts): try: response = requests.get("http://localhost:8000/info") if response.status_code == 200: return True - except requests.exceptions.ConnectionError: - pass - time.sleep(1) + except requests.exceptions.ConnectionError as exc: + print(f"Waiting for API server (attempt {attempt + 1}/{attempts}): {exc}") + time.sleep(interval) return False +def _get_free_tcp_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + @pytest.fixture(scope="module") def api_server(): """Launch API server for testing.""" + prev_endpoint = os.environ.get("ATROPOS_MESSAGE_BUS_ENDPOINT") + prev_enable = os.environ.get("ATROPOS_ENABLE_MESSAGE_BUS") + bus_port = _get_free_tcp_port() + os.environ["ATROPOS_MESSAGE_BUS_ENDPOINT"] = f"tcp://127.0.0.1:{bus_port}" + if prev_enable is None: + os.environ["ATROPOS_ENABLE_MESSAGE_BUS"] = "1" + # Start the API server as a subprocess proc = subprocess.Popen( [ @@ -46,7 +62,23 @@ def api_server(): # Wait for server to be ready if not wait_for_api_server(): proc.terminate() - raise RuntimeError("API 
server failed to start") + try: + stdout, stderr = proc.communicate(timeout=1) + except TimeoutExpired: + proc.kill() + stdout, stderr = proc.communicate() + stderr = stderr.decode() if stderr else "" + stdout = stdout.decode() if stdout else "" + msg = f"API server failed to start. stdout:\n{stdout}\nstderr:\n{stderr}" + if prev_endpoint is None: + os.environ.pop("ATROPOS_MESSAGE_BUS_ENDPOINT", None) + else: + os.environ["ATROPOS_MESSAGE_BUS_ENDPOINT"] = prev_endpoint + if prev_enable is None: + os.environ.pop("ATROPOS_ENABLE_MESSAGE_BUS", None) + else: + os.environ["ATROPOS_ENABLE_MESSAGE_BUS"] = prev_enable + raise RuntimeError(msg) yield @@ -60,6 +92,15 @@ def api_server(): except Exception: pass + if prev_endpoint is None: + os.environ.pop("ATROPOS_MESSAGE_BUS_ENDPOINT", None) + else: + os.environ["ATROPOS_MESSAGE_BUS_ENDPOINT"] = prev_endpoint + if prev_enable is None: + os.environ.pop("ATROPOS_ENABLE_MESSAGE_BUS", None) + else: + os.environ["ATROPOS_ENABLE_MESSAGE_BUS"] = prev_enable + @pytest.fixture(autouse=True) def reset_api_state(): diff --git a/atroposlib/tests/test_message_bus.py b/atroposlib/tests/test_message_bus.py new file mode 100644 index 000000000..3daddc0bc --- /dev/null +++ b/atroposlib/tests/test_message_bus.py @@ -0,0 +1,375 @@ +import asyncio +import socket +import sys +import time +import types +from types import SimpleNamespace +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pytest +from fastapi.testclient import TestClient + +try: + import wandb # type: ignore +except ModuleNotFoundError: # pragma: no cover + wandb = types.SimpleNamespace( + Settings=lambda **kwargs: SimpleNamespace(**kwargs), + Table=lambda *args, **kwargs: None, + init=lambda **kwargs: None, + log=lambda *args, **kwargs: None, + finish=lambda *args, **kwargs: None, + ) + sys.modules["wandb"] = wandb + +from atroposlib.api import server +from atroposlib.envs.base import BaseEnv +from atroposlib.utils.message_bus import 
MessageBusClient + + +def _get_free_tcp_port() -> int: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("127.0.0.1", 0)) + port = sock.getsockname()[1] + sock.close() + return port + + +class DummyTable: + def __init__(self, columns: List[str]): + self.columns = columns + self.rows: List[Tuple[Any, ...]] = [] + + def add_data(self, *values: Any) -> None: + self.rows.append(values) + + +def test_api_message_bus_routes_metrics_to_wandb(monkeypatch): + port = _get_free_tcp_port() + endpoint = f"tcp://127.0.0.1:{port}" + + monkeypatch.setattr(server, "MESSAGE_BUS_ENABLED", True) + monkeypatch.setattr(server, "MESSAGE_BUS_ENDPOINT", endpoint) + + init_calls: List[Dict[str, Any]] = [] + logged_calls: List[Tuple[Dict[str, Any], Any]] = [] + + monkeypatch.setattr(server.wandb, "Table", DummyTable) + monkeypatch.setattr( + server.wandb, + "init", + lambda **kwargs: init_calls.append(kwargs) or object(), + ) + monkeypatch.setattr( + server.wandb, + "log", + lambda metrics, step=None: logged_calls.append((metrics, step)), + ) + monkeypatch.setattr(server.wandb, "finish", lambda *args, **kwargs: None) + + with TestClient(server.app) as client: + response = client.post( + "/register", + json={ + "wandb_group": "test_group", + "wandb_project": "test_project", + "batch_size": 4, + "max_token_len": 128, + "checkpoint_dir": "/tmp", + "save_checkpoint_interval": 10, + "starting_step": 0, + "num_steps": 100, + }, + ) + assert response.status_code == 200 + assert init_calls, "wandb.init should be invoked during trainer registration" + + server.app.state.started = True + + env_response = client.post( + "/register-env", + json={ + "max_token_length": 128, + "desired_name": "env", + "weight": 1.0, + "group_size": 2, + "min_batch_allocation": None, + }, + ) + data = env_response.json() + assert env_response.status_code == 200 + message_bus = data.get("message_bus") + assert message_bus, "message bus details should be returned when enabled" + assert 
@pytest.mark.asyncio
async def test_message_bus_worker_processes_registered_messages(monkeypatch):
    """The worker forwards a message with a registered token to _log_metrics."""
    token = "tok123"
    messages: List[Dict[str, Any]] = [
        {"token": token, "type": "metrics", "payload": 1},
    ]

    class DummySocket:
        """Replays canned messages, then raises CancelledError to stop the worker."""

        def __init__(self, responses: List[Dict[str, Any]]):
            self._responses = list(responses)
            self.closed = False

        async def recv_json(self):
            if self._responses:
                return self._responses.pop(0)
            raise asyncio.CancelledError()

        def close(self, linger: int = 0):
            self.closed = True

        def setsockopt(self, *args, **kwargs):
            pass

    # Named dummy_socket so the imported socket module is not shadowed.
    dummy_socket = DummySocket(messages)
    logged_messages: List[Dict[str, Any]] = []

    async def fake_log(message: Dict[str, Any]) -> None:
        logged_messages.append(message)

    monkeypatch.setattr(server, "_log_metrics", fake_log)

    original_socket = getattr(server.app.state, "message_bus_socket", None)
    original_tokens = getattr(server.app.state, "message_bus_tokens", None)
    server.app.state.message_bus_socket = dummy_socket
    server.app.state.message_bus_tokens = {token: {"registered_id": 0}}
    try:
        await server._message_bus_worker()
    finally:
        # Restore shared app state even if the worker raises, so a failure
        # here cannot leak into later tests.
        server.app.state.message_bus_socket = original_socket
        server.app.state.message_bus_tokens = original_tokens

    assert logged_messages and logged_messages[0]["payload"] == 1
    assert dummy_socket.closed is False
class DummyMessageBusClient:
    """In-memory stand-in for MessageBusClient that records sent payloads.

    It offers the same ``send_json`` / ``close`` coroutines as the real
    client, so code that closes the bus also works against the dummy.
    """

    def __init__(self):
        self.messages: List[Dict[str, Any]] = []
        self.closed = False

    async def send_json(self, payload: Dict[str, Any]) -> None:
        """Record the payload; like the real client, drop it once closed."""
        if self.closed:
            return
        self.messages.append(payload)

    async def close(self) -> None:
        """Mark the client as closed."""
        self.closed = True


class DummyServer:
    """Fake inference server that reports one numpy-typed latency metric."""

    async def wandb_metrics(
        self, metrics_dict: Optional[Dict[str, Any]], server_name: Optional[str]
    ):
        """Return the metrics dict with ``<server_name>_latency`` added."""
        metrics_dict = metrics_dict or {}
        # A numpy scalar on purpose: callers must sanitise it before JSON.
        metrics_dict[f"{server_name}_latency"] = np.float32(3.25)
        return metrics_dict
env_id=42, + message_bus_env_name="env", + message_bus_wandb_prepend="env_0", + task_duration=[0.5, 0.7], + succeeded_task_duration=[0.3, 0.4], + failed_task_duration=[], + mainloop_timings=[0.1, 0.2], + workers_added_list=[1, 2], + ) + + dummy._sanitize_for_json = BaseEnv._sanitize_for_json.__get__(dummy, BaseEnv) + dummy.create_rollout_table = BaseEnv.create_rollout_table.__get__(dummy, BaseEnv) + dummy.perf_stats = BaseEnv.perf_stats.__get__(dummy, BaseEnv) + + if not use_bus: + monkeypatch.setattr("atroposlib.envs.base.wandb.Table", DummyTable) + + return dummy, dummy_client + + +def test_base_env_wandb_log_uses_message_bus(monkeypatch): + dummy_env, dummy_client = _build_dummy_env(use_bus=True, monkeypatch=monkeypatch) + + def fail_log(*args, **kwargs): + raise AssertionError( + "wandb.log should not be called when message bus is active" + ) + + monkeypatch.setattr("atroposlib.envs.base.wandb.log", fail_log) + + asyncio.run( + BaseEnv.wandb_log( + dummy_env, + {"custom_metric": np.float32(1.5)}, + ) + ) + + assert dummy_client.messages, "Message bus client should receive payloads" + payload = dummy_client.messages[0] + assert payload["env_id"] == 42 + assert payload["metrics"]["custom_metric"] == pytest.approx(1.5) + assert payload["server_metrics"]["server_0_latency"] == pytest.approx(3.25) + assert payload["rollouts"][0][0][0] == "decoded text" + assert dummy_env.rollouts_for_wandb == [] + assert dummy_env.completion_lengths == [] + + +def test_base_env_wandb_log_falls_back_to_local_wandb(monkeypatch): + dummy_env, _ = _build_dummy_env(use_bus=False, monkeypatch=monkeypatch) + + logged_calls: List[Tuple[Dict[str, Any], Any]] = [] + + monkeypatch.setattr( + "atroposlib.envs.base.wandb.log", + lambda metrics, step=None: logged_calls.append((metrics, step)), + ) + + asyncio.run( + BaseEnv.wandb_log( + dummy_env, + {"local_metric": 2.0}, + ) + ) + + assert logged_calls, "wandb.log should be called when message bus is absent" + metrics_logged, step_logged = 
class MessageBusClient:
    """ZMQ PUSH client used by environments to publish payloads.

    Every payload is stamped with the client's token unless it already
    carries one. Sends are serialised with an asyncio lock. The client can
    also be used as an async context manager, which closes it on exit.
    """

    def __init__(
        self,
        endpoint: str,
        token: str,
        *,
        linger: int = 0,
        snd_hwm: int = 0,
    ) -> None:
        """Create the socket and connect it to ``endpoint``.

        Args:
            endpoint: ZMQ endpoint to connect to.
            token: Token added to every outgoing message.
            linger: Value for the ``LINGER`` socket option.
            snd_hwm: Value for the ``SNDHWM`` socket option; a negative value
                leaves the option untouched.

        Raises:
            Exception: Whatever ``setsockopt`` or ``connect`` raises; the
                socket is closed before the error propagates.
        """
        self._endpoint = endpoint
        self._token = token
        self._closed = False
        self._lock = asyncio.Lock()
        self._context = zmq.asyncio.Context.instance()
        self._socket = self._context.socket(zmq.PUSH)
        try:
            self._socket.setsockopt(zmq.LINGER, linger)
            if snd_hwm >= 0:
                self._socket.setsockopt(zmq.SNDHWM, snd_hwm)
            self._socket.connect(endpoint)
        except Exception:
            # Do not leak the native socket when configuration fails.
            self._closed = True
            self._socket.close(linger=0)
            raise

    @property
    def endpoint(self) -> str:
        """Endpoint this client was connected to."""
        return self._endpoint

    @property
    def token(self) -> str:
        """Token stamped on outgoing messages."""
        return self._token

    async def send_json(self, payload: Dict[str, Any]) -> None:
        """Send a JSON serialisable payload over the message bus.

        The payload is copied, so the caller's dict is never modified. Once
        the client is closed the call returns without sending.
        """
        if self._closed:
            return
        message = dict(payload)
        message.setdefault("token", self._token)
        async with self._lock:
            # close() may have run while this call waited for the lock.
            if self._closed:
                return
            await self._socket.send_json(message)

    async def close(self) -> None:
        """Close the socket; calling it again is a no-op."""
        if self._closed:
            return
        self._closed = True
        self._socket.close(linger=0)

    async def __aenter__(self) -> "MessageBusClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    def __del__(self) -> None:
        # _socket is missing when __init__ failed before creating it.
        sock = getattr(self, "_socket", None)
        if sock is None:
            return
        try:
            sock.close(linger=0)
        except Exception:
            # Best effort only: never raise during garbage collection.
            pass