diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index edad30414d..fac07bc5ac 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1963,6 +1963,68 @@ def __post_init__(self): raise ValueError("admin_api_key must not be empty or whitespace-only") +@dataclass +class AgentConfig: + """Configuration for the experimental agent service controller.""" + + agent_cls_path: str = field( + default="", + metadata={ + "help": "Fully-qualified import path for the AgentRunnable implementation." + }, + ) + admin_api_key: str = field( + default="areal-agent-admin", + metadata={"help": "Shared admin API key for agent-service inter-service auth."}, + ) + num_pairs: int = field( + default=1, + metadata={"help": "Number of Worker+DataProxy pairs to launch on initialize."}, + ) + setup_timeout: float = field( + default=120.0, + metadata={ + "help": "Timeout in seconds waiting for each service to become healthy." + }, + ) + health_poll_interval: float = field( + default=5.0, + metadata={ + "help": "Seconds between pair health polls; 0 disables health monitoring." + }, + ) + drain_timeout: float = field( + default=30.0, + metadata={ + "help": "Seconds to wait for active sessions to drain before force-killing a pair." + }, + ) + log_level: str = field( + default="info", + metadata={"help": "Log level for spawned agent-service micro-services."}, + ) + env: dict[str, str] = field( + default_factory=dict, + metadata={ + "help": "Extra environment variables passed to all forked child processes." + }, + ) + + def __post_init__(self) -> None: + if not self.agent_cls_path: + raise ValueError("agent_cls_path must be a non-empty import path") + if self.num_pairs < 0: + raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}") + if self.setup_timeout <= 0: + raise ValueError( + f"setup_timeout must be positive, got {self.setup_timeout}" + ) + if self.drain_timeout < 0: + raise ValueError( + f"drain_timeout must be non-negative, got {self.drain_timeout}" + ) + + @dataclass class InferenceEngineConfig: """Configuration for inference servers, including offpolicyness control.""" @@ -2081,6 +2143,61 @@ class InferenceEngineConfig: }, ) + # v2 controller options + _version: str = field( + default="v1", + metadata={ + "help": "Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2.", + "choices": ["v1", "v2"], + }, + ) + model: str = field( + default="default", + metadata={"help": "Model name exposed through the inference-service gateway."}, + ) + routing_strategy: str = field( + default="round_robin", + metadata={"help": "Routing strategy for the inference-service router."}, + ) + poll_interval: float = field( + default=5.0, + metadata={ + "help": "Health-poll interval in seconds for the inference-service router." + }, + ) + set_reward_finish_timeout: float = field( + default=0.0, + metadata={ + "help": "Timeout in seconds to wait for additional reward updates before finalizing a session." + }, + ) + log_level: str = field( + default="info", + metadata={"help": "Log level for inference-service micro-services."}, + ) + admin_api_key: str = field( + default="areal-admin-key", + metadata={ + "help": "Admin API key used by the inference-service gateway, router, and data proxies." + }, + ) + api_url: str | None = field( + default=None, + metadata={ + "help": "External OpenAI-compatible base URL for inference-service external model mode." + }, + ) + provider_api_key: str | None = field( + default=None, + metadata={"help": "API key for the external OpenAI-compatible provider."}, + ) + n_gpus_per_node: int | None = field( + default=None, + metadata={ + "help": "GPUs per physical node for multinode inference-service launch." + }, + ) + def __post_init__(self): """Validate scheduling_spec length.""" if len(self.scheduling_spec) not in (1, 2): @@ -2088,6 +2205,25 @@ def __post_init__(self): f"scheduling_spec must contain 1 or 2 SchedulingSpec, " f"got {len(self.scheduling_spec)}" ) + if self._version not in ("v1", "v2"): + raise ValueError( + f"_version must be either 'v1' or 'v2', got '{self._version}'" + ) + if self.n_gpus_per_node is not None and self.n_gpus_per_node < 1: + raise ValueError( + f"n_gpus_per_node must be >= 1, got {self.n_gpus_per_node}" + ) + if not self.admin_api_key or not self.admin_api_key.strip(): + raise ValueError("admin_api_key must not be empty or whitespace-only") + if ( + self._version == "v2" + and self.openai is not None + and self.openai.admin_api_key != "areal-admin-key" + ): + logger.warning( + "rollout.openai.admin_api_key is ignored by rollout controller v2; " + "use rollout.admin_api_key instead." + ) @dataclass diff --git a/areal/experimental/agent_service/README.md b/areal/experimental/agent_service/README.md index f3dc0f839f..c496f61205 100644 --- a/areal/experimental/agent_service/README.md +++ b/areal/experimental/agent_service/README.md @@ -4,7 +4,9 @@ The Agent Service provides **agent-level** capabilities on top of AReaL's model-level proxy. It exposes complete agent sessions — multi-turn conversations with tool use, -memory, and pluggable agent frameworks — via independent HTTP microservices. +memory, and pluggable agent frameworks — via independent HTTP microservices. It also +includes an `AgentController` that can launch the stack through Guard processes and +bridge agent conversations to the experimental inference service for RL data collection. ## Architecture @@ -47,6 +49,10 @@ at startup. Each `POST /run` request is a single turn — the agent receives the conversation history in the request and returns a response. The Worker has no session state. +**AgentController** — Python orchestrator that launches Guards via the scheduler, forks +the Router / Gateway / Worker+DataProxy pairs onto them, supports scale-up and +scale-down, and exposes async runtime APIs for inference-backed RL sessions. + ## Agent Protocol Any class that satisfies the `AgentRunnable` protocol can run on the Worker: @@ -129,6 +135,29 @@ class EventEmitter(Protocol): | `/ws` | WS | Gateway WebSocket protocol | | `/v1/responses` | POST | OpenResponses HTTP bridge | +## AgentController Runtime APIs + +`AgentController` is the integration point used by the examples and rollout workflows. +It manages the agent-service stack and exposes async helpers for RL/inference flows: + +| Method | Description | +| ----------------------------------------------------- | ------------------------------------------------------------------------------ | +| `initialize()` | Launch Guards, Router, Worker+DataProxy pairs, Gateway, and the health monitor | +| `destroy()` | Tear down the full stack in reverse order | +| `scale_up(count)` | Add Worker+DataProxy pairs | +| `scale_down(count)` | Unregister, drain, and remove pairs | +| `start_session(...)` | Grant inference capacity and create an RL session bound to an agent session | +| `step(input, session_id, metadata=None)` | Send a turn through the agent-service Gateway `POST /v1/responses` | +| `set_reward(reward, session_id, interaction_id=None)` | Forward the final reward to the inference service | +| `export_trajectory(session_id, ...)` | Export serialized interactions from the inference service | + +Typical rollout flow: + +1. `start_session()` to create the agent/inference session pair. +1. `step()` for each user turn. +1. `set_reward()` when the episode completes. +1. `export_trajectory()` to retrieve interactions for training. + ## Multi-turn Conversation Flow ``` @@ -159,9 +188,8 @@ areal/experimental/agent_service/ ├── protocol.py # Gateway protocol frame types ├── types.py # AgentRequest, AgentResponse, EventEmitter, AgentRunnable ├── controller/ -│ ├── __init__.py # AgentServiceController, AgentServiceControllerConfig -│ ├── config.py # AgentServiceControllerConfig dataclass -│ └── controller.py # AgentServiceController orchestrator +│ ├── __init__.py # AgentController export +│ └── controller.py # AgentController orchestrator ├── guard/ │ ├── __init__.py # Module docstring │ ├── __main__.py # python -m areal.experimental.agent_service.guard @@ -190,8 +218,25 @@ areal/experimental/agent_service/ ├── app.py # create_worker_app() └── config.py # WorkerConfig dataclass -examples/agent_service/ -├── agent.py # ClaudeAgent (Claude Agent SDK) -├── run_agent_service.py # Controller-based launcher + interactive demo +examples/experimental/agent_service/ +├── __init__.py # Marks the examples package +├── claude/ +│ ├── __init__.py # Claude example package +│ ├── agent.py # ClaudeAgent (Claude Agent SDK) +│ └── run_agent_service.py # Controller-based launcher + interactive demo +├── tau2/ +│ ├── __init__.py # Tau2 example package +│ ├── agent.py # Tau2 agent-service worker example +│ ├── workflow.py # Tau2 workflow using async controller APIs +│ ├── run_rollout.py # Direct rollout driver for Tau2 +│ └── config.yaml # Tau2 example config └── README.md # Example documentation ``` + +For a standalone worker process, the agent import path now points at the nested Claude +example module: + +```bash +python -m areal.experimental.agent_service.worker \ + --agent examples.experimental.agent_service.claude.agent.ClaudeAgent +``` diff --git a/areal/experimental/agent_service/__init__.py b/areal/experimental/agent_service/__init__.py index 3858d5c133..2c964e550f 100644 --- a/areal/experimental/agent_service/__init__.py +++ b/areal/experimental/agent_service/__init__.py @@ -8,7 +8,7 @@ Submodules ---------- -- ``controller`` — :class:`AgentServiceController` orchestrator +- ``controller`` — :class:`AgentController` orchestrator - ``gateway`` — public HTTP/WebSocket entry point - ``router`` — session-affine routing - ``data_proxy`` — stateful session proxy diff --git a/areal/experimental/agent_service/controller/__init__.py b/areal/experimental/agent_service/controller/__init__.py index 3150205885..a677fc8b08 100644 --- a/areal/experimental/agent_service/controller/__init__.py +++ b/areal/experimental/agent_service/controller/__init__.py @@ -2,10 +2,11 @@ """Agent Service Controller — orchestrator for agent micro-services.""" -from .config import AgentServiceControllerConfig -from .controller import AgentServiceController +from areal.api.cli_args import AgentConfig + +from .controller import AgentController __all__ = [ - "AgentServiceController", - "AgentServiceControllerConfig", + "AgentController", + "AgentConfig", ] diff --git a/areal/experimental/agent_service/controller/config.py b/areal/experimental/agent_service/controller/config.py deleted file mode 100644 index c316d58227..0000000000 --- a/areal/experimental/agent_service/controller/config.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -"""Configuration for the AgentServiceController.""" - -from __future__ import annotations - -from dataclasses import dataclass, field - -from ..auth import DEFAULT_ADMIN_API_KEY - - -@dataclass -class AgentServiceControllerConfig: - """Unified configuration for AgentServiceController. - - Consolidates settings for the guard, router, gateway, worker, and - data proxy micro-services launched by the controller. - """ - - # -- Agent class ------------------------------------------------------- - agent_cls_path: str = "" - """Fully-qualified import path for the ``AgentRunnable`` implementation - (e.g. ``examples.agent_service.agent.Tau2Agent``).""" - - # -- Authentication ---------------------------------------------------- - admin_api_key: str = DEFAULT_ADMIN_API_KEY - """Shared admin API key for inter-service Bearer auth.""" - - # -- Scaling ----------------------------------------------------------- - num_pairs: int = 1 - """Number of Worker+DataProxy pairs to launch on initialize.""" - - # -- Timeouts ---------------------------------------------------------- - setup_timeout: float = 120.0 - """Timeout (seconds) waiting for each service to become healthy.""" - - health_poll_interval: float = 5.0 - """Seconds between health polls for crash detection (0 = disabled).""" - - drain_timeout: float = 30.0 - """Seconds to wait for active sessions to drain before force-killing a pair.""" - - # -- Log level --------------------------------------------------------- - log_level: str = "info" - """Log level for spawned micro-services.""" - - # -- Environment ------------------------------------------------------- - env: dict[str, str] = field(default_factory=dict) - """Extra environment variables to pass to all forked child processes.""" - - def __post_init__(self) -> None: - if not self.agent_cls_path: - raise ValueError("agent_cls_path must be a non-empty import path") - if self.num_pairs < 0: - raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}") - if self.setup_timeout <= 0: - raise ValueError( - f"setup_timeout must be positive, got {self.setup_timeout}" - ) - if self.drain_timeout < 0: - raise ValueError( - f"drain_timeout must be non-negative, got {self.drain_timeout}" - ) diff --git a/areal/experimental/agent_service/controller/controller.py b/areal/experimental/agent_service/controller/controller.py index 21b12851bb..19f662e331 100644 --- a/areal/experimental/agent_service/controller/controller.py +++ b/areal/experimental/agent_service/controller/controller.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -"""AgentServiceController — orchestrates agent service micro-services via Guards. +"""AgentController — orchestrates agent service micro-services via Guards. Mirrors the architecture of -:class:`~areal.experimental.inference_service.controller.controller.GatewayInferenceController`: +:class:`~areal.experimental.inference_service.controller.controller.RolloutControllerV2`: Guard workers are created via the Scheduler, then the controller forks Router, Worker+DataProxy pairs, and Gateway onto them via HTTP API. @@ -12,7 +12,7 @@ from areal.infra.scheduler.local import LocalScheduler scheduler = LocalScheduler(...) - controller = AgentServiceController(config, scheduler) + controller = AgentController(config, scheduler) controller.initialize() # ... run traffic ... controller.scale_up(2) # add 2 Worker+DataProxy pairs @@ -26,26 +26,30 @@ import threading import time import traceback +import uuid from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from typing import TYPE_CHECKING, Any +import aiohttp import requests -from areal.experimental.agent_service.controller.config import ( - AgentServiceControllerConfig, -) +from areal.api.cli_args import AgentConfig +from areal.experimental.openai.proxy.server import deserialize_interactions from areal.utils import logging from areal.utils.network import format_hostport if TYPE_CHECKING: from areal.api.scheduler_api import Scheduler, Worker + from areal.experimental.openai.types import InteractionWithTokenLogpReward -logger = logging.getLogger("AgentServiceController") +logger = logging.getLogger("AgentController") _GUARD_ROLE = "agent-guard" _UNREGISTER_RETRIES = 3 _HEALTH_CHECK_WORKERS = 4 +_DEFAULT_RUNTIME_TIMEOUT = 600.0 +_DEFAULT_INFERENCE_ADMIN_API_KEY = "areal-admin-key" @dataclass @@ -60,7 +64,17 @@ class _WorkerPair: worker_addr: str -class AgentServiceController: +@dataclass +class _RuntimeSession: + agent_session_id: str + inference_gateway_addr: str + inference_admin_api_key: str + inference_session_id: str + inference_session_api_key: str + inference_model: str = "" + + +class AgentController: """Orchestrator for the Agent Service micro-service stack. Parameters @@ -73,7 +87,7 @@ class AgentServiceController: def __init__( self, - config: AgentServiceControllerConfig, + config: AgentConfig, scheduler: Scheduler, ) -> None: self.config = config @@ -91,6 +105,8 @@ def __init__( self._next_pair_index: int = 0 self._forked_services: list[tuple[str, str, int]] = [] + self._sessions: dict[str, _RuntimeSession] = {} + self._sessions_lock = threading.Lock() self._health_stop = threading.Event() self._health_thread: threading.Thread | None = None @@ -219,6 +235,8 @@ def destroy(self) -> None: self._guard_addrs.clear() with self._pairs_lock: self._pairs.clear() + with self._sessions_lock: + self._sessions.clear() self._router_addr = "" self._gateway_addr = "" @@ -371,6 +389,166 @@ def pairs(self) -> dict[int, _WorkerPair]: with self._pairs_lock: return dict(self._pairs) + # ------------------------------------------------------------------ + # Runtime APIs + # ------------------------------------------------------------------ + + @staticmethod + def _bearer_headers(api_key: str) -> dict[str, str]: + return {"Authorization": f"Bearer {api_key}"} + + async def _post_json( + self, + url: str, + payload: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + *, + timeout: float = _DEFAULT_RUNTIME_TIMEOUT, + expect_json: bool = True, + ) -> dict[str, Any]: + client_timeout = aiohttp.ClientTimeout(total=timeout) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + async with session.post(url, json=payload, headers=headers) as resp: + resp.raise_for_status() + if not expect_json: + return {} + return await resp.json() + + async def _grant_capacity( + self, + inference_gateway_addr: str, + inference_admin_api_key: str, + ) -> None: + await self._post_json( + f"{inference_gateway_addr.rstrip('/')}/grant_capacity", + headers=self._bearer_headers(inference_admin_api_key), + expect_json=False, + ) + + async def start_session( + self, + task_id: str, + *, + inference_gateway_addr: str, + inference_admin_api_key: str = _DEFAULT_INFERENCE_ADMIN_API_KEY, + inference_model: str = "", + api_key: str | None = None, + ) -> dict[str, str]: + agent_session_id = f"agent-sess-{uuid.uuid4().hex[:12]}" + normalized_task_id = task_id or agent_session_id + gateway_addr = inference_gateway_addr.rstrip("/") + + await self._grant_capacity(gateway_addr, inference_admin_api_key) + + payload: dict[str, Any] = {"task_id": normalized_task_id} + if api_key is not None: + payload["api_key"] = api_key + + data = await self._post_json( + f"{gateway_addr}/rl/start_session", + payload=payload, + headers=self._bearer_headers(inference_admin_api_key), + ) + + session = _RuntimeSession( + agent_session_id=agent_session_id, + inference_gateway_addr=gateway_addr, + inference_admin_api_key=inference_admin_api_key, + inference_session_id=data["session_id"], + inference_session_api_key=data["api_key"], + inference_model=inference_model, + ) + with self._sessions_lock: + self._sessions[agent_session_id] = session + + return { + "session_id": agent_session_id, + "inference_session_id": session.inference_session_id, + "api_key": session.inference_session_api_key, + } + + async def step( + self, + input: str | list[dict[str, Any]], + session_id: str, + *, + metadata: dict[str, Any] | None = None, + ) -> dict[str, Any]: + if not self._gateway_addr: + raise RuntimeError( + "step() requires the agent-service gateway to be running" + ) + + session = self._resolve_session(session_id) + input_items = ( + [{"type": "message", "content": input}] if isinstance(input, str) else input + ) + + merged_metadata: dict[str, Any] = { + "inference_base_url": session.inference_gateway_addr, + "inference_api_key": session.inference_session_api_key, + } + if session.inference_model: + merged_metadata["inference_model"] = session.inference_model + if metadata: + merged_metadata.update(metadata) + + body: dict[str, Any] = { + "input": input_items, + "model": (session.inference_model or "default").replace("/", "--"), + "user": session.agent_session_id, + } + if merged_metadata: + body["metadata"] = merged_metadata + + return await self._post_json( + f"{self._gateway_addr}/v1/responses", + payload=body, + headers=self._bearer_headers(self.config.admin_api_key), + ) + + async def set_reward( + self, + reward: float, + session_id: str, + *, + interaction_id: str | None = None, + ) -> dict[str, Any]: + session = self._resolve_session(session_id) + return await self._post_json( + f"{session.inference_gateway_addr}/rl/set_reward", + payload={"interaction_id": interaction_id, "reward": reward}, + headers=self._bearer_headers(session.inference_session_api_key), + ) + + async def export_trajectory( + self, + session_id: str, + *, + trajectory_id: int | None = None, + discount: float = 1.0, + style: str = "individual", + ) -> dict[str, InteractionWithTokenLogpReward]: + session = self._resolve_session(session_id) + data = await self._post_json( + f"{session.inference_gateway_addr}/export_trajectories", + payload={ + "session_id": session.inference_session_id, + "trajectory_id": trajectory_id, + "discount": discount, + "style": style, + }, + headers=self._bearer_headers(session.inference_admin_api_key), + ) + return deserialize_interactions(data["interactions"]) + + def _resolve_session(self, session_id: str) -> _RuntimeSession: + with self._sessions_lock: + session = self._sessions.get(session_id) + if session is None: + raise KeyError(f"Unknown session_id: {session_id!r}") + return session + # ------------------------------------------------------------------ # Guard interaction helpers # ------------------------------------------------------------------ diff --git a/areal/experimental/agent_service/worker/__main__.py b/areal/experimental/agent_service/worker/__main__.py index c14d52eba7..8a2ec97654 100644 --- a/areal/experimental/agent_service/worker/__main__.py +++ b/areal/experimental/agent_service/worker/__main__.py @@ -6,7 +6,7 @@ via Guard to create Worker+DataProxy pairs. python -m areal.experimental.agent_service.worker \ - --agent examples.agent_service.agent.ClaudeAgent \ + --agent examples.experimental.agent_service.claude.agent.ClaudeAgent \ --host 127.0.0.1 --port 9000 """ diff --git a/areal/experimental/inference_service/controller/__init__.py b/areal/experimental/inference_service/controller/__init__.py index e122baa6db..9ac57d0d68 100644 --- a/areal/experimental/inference_service/controller/__init__.py +++ b/areal/experimental/inference_service/controller/__init__.py @@ -1,10 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) +from areal.api.cli_args import InferenceEngineConfig from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) -__all__ = ["GatewayControllerConfig", "GatewayInferenceController"] +__all__ = ["InferenceEngineConfig", "RolloutControllerV2"] diff --git a/areal/experimental/inference_service/controller/config.py b/areal/experimental/inference_service/controller/config.py deleted file mode 100644 index 4d45b1391b..0000000000 --- a/areal/experimental/inference_service/controller/config.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -"""Configuration for the GatewayInferenceController.""" - -from __future__ import annotations - -from dataclasses import dataclass, field - - -@dataclass -class GatewayControllerConfig: - """Unified configuration for GatewayInferenceController. - - Consolidates settings for the gateway, router, data proxy services, - and the WorkflowExecutor / staleness management. - """ - - # -- Model / tokenizer ------------------------------------------------- - tokenizer_path: str = "" - model_path: str = "" - model: str = "default" - - # -- Routing ----------------------------------------------------------- - routing_strategy: str = "round_robin" - poll_interval: float = 5.0 # router health-poll interval (seconds) - - # -- HTTP timeouts ----------------------------------------------------- - request_timeout: float = 120.0 # per-request timeout (seconds) - setup_timeout: float = 300.0 # timeout waiting for services to start - set_reward_finish_timeout: float = 0.0 - - # -- Log level for gateway micro-services ------------------------------ - log_level: str = "info" - - # -- WorkflowExecutor / staleness -------------------------------------- - consumer_batch_size: int = 16 - max_concurrent_rollouts: int | None = None - max_head_offpolicyness: int = 0 - queue_size: int | None = None - enable_rollout_tracing: bool = False - - # -- Trajectory dump --------------------------------------------------- - fileroot: str | None = None - experiment_name: str | None = None - trial_name: str | None = None - check_trajectory_format: bool = False - dump_to_file: bool = False - - # -- Scheduler / allocation (passed through from trainer) -------------- - backend: str = "sglang:d1" - scheduling_spec: tuple = field(default_factory=tuple) - pause_grace_period: float = 0.5 - n_gpus_per_node: int | None = None # GPUs per physical node; None = single-node - - # -- Admin / workflow -------------------------------------------------- - admin_api_key: str | None = None - turn_discount: float = 1.0 - export_style: str = "individual" - tool_call_parser: str = "qwen" - reasoning_parser: str = "qwen3" - engine_max_tokens: int | None = None - chat_template_type: str = "hf" - - # -- External model API ------------------------------------------------ - api_url: str | None = None - provider_api_key: str | None = None diff --git a/areal/experimental/inference_service/controller/controller.py b/areal/experimental/inference_service/controller/controller.py index 7cc546a554..45de0ea17d 100644 --- a/areal/experimental/inference_service/controller/controller.py +++ b/areal/experimental/inference_service/controller/controller.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -"""GatewayInferenceController — parallel implementation to RolloutController. +"""RolloutControllerV2 — parallel implementation to RolloutController. Routes inference and pause/continue traffic through the gateway HTTP stack (Gateway → Router → Data Proxy → inference backend). @@ -11,6 +11,7 @@ from __future__ import annotations import asyncio +import copy import os import sys import threading @@ -28,14 +29,12 @@ if TYPE_CHECKING: from areal.api.scheduler_api import Scheduler, Worker +from areal.api.cli_args import InferenceEngineConfig from areal.api.io_struct import LocalInfServerInfo -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) from areal.utils import logging from areal.utils.network import format_hostport -logger = logging.getLogger("GatewayInferenceController") +logger = logging.getLogger("RolloutControllerV2") _MAX_COMPLETED_ONLINE_RESULTS = 1024 @@ -48,7 +47,7 @@ class _OnlineWaiter: class _DummyDataLoader: """Minimal dataloader that yields a single batch of empty dicts. - Used by :meth:`GatewayInferenceController.prepare_batch` when + Used by :meth:`RolloutControllerV2.prepare_batch` when ``dataloader`` is ``None`` (online-agent mode). """ @@ -59,7 +58,7 @@ def __iter__(self): yield [{} for _ in range(self.batch_size)] -class GatewayInferenceController: +class RolloutControllerV2: """Inference controller that routes everything through the gateway HTTP stack. This is a **parallel** implementation to ``RolloutController`` (NOT a @@ -79,15 +78,15 @@ class GatewayInferenceController: def __init__( self, - config: GatewayControllerConfig, + config: InferenceEngineConfig, scheduler: Scheduler, ) -> None: - if config.admin_api_key is None: + if config.admin_api_key is None or not config.admin_api_key.strip(): raise ValueError( - "GatewayControllerConfig.admin_api_key must be set (not None)" + "InferenceEngineConfig.admin_api_key must be set (not None or empty)" ) if not config.model: - raise ValueError("GatewayControllerConfig.model must not be empty") + raise ValueError("InferenceEngineConfig.model must not be empty") self.config = config self.scheduler = scheduler @@ -199,7 +198,6 @@ def initialize( self._register_data_proxies_in_router() # Create WorkflowExecutor directly (no intermediate engine) - from areal.api.cli_args import InferenceEngineConfig from areal.infra.remote_inf_engine import RemoteInfEngine from areal.infra.workflow_executor import WorkflowExecutor @@ -222,7 +220,7 @@ def initialize( max_staleness=self.config.max_head_offpolicyness, ) - logger.info("GatewayInferenceController initialized (role=%s)", role) + logger.info("RolloutControllerV2 initialized (role=%s)", role) if self.config.model: self.register_model( @@ -237,6 +235,18 @@ def initialize( self.config.model, ) + def offload(self) -> None: + """Offload hook placeholder for trainer compatibility.""" + logger.warning( + "RolloutControllerV2.offload is not implemented and will be skipped" + ) + + def onload(self, tags: list[str] | None = None) -> None: + """Onload hook placeholder for trainer compatibility.""" + logger.warning( + "RolloutControllerV2.onload is not implemented and will be skipped" + ) + async def _async_initialize( self, server_args: dict[str, Any] | None, @@ -266,6 +276,7 @@ async def _async_initialize( cfg = self.config admin_api_key = self.config.admin_api_key + openai_cfg = self._openai_config if self.external_mode: dp_size = 1 @@ -350,10 +361,9 @@ async def _async_initialize( if inf_backend == "sglang": from areal.api.cli_args import SGLangConfig - sglang_config = SGLangConfig( - model_path=cfg.model_path or cfg.tokenizer_path, - ) + sglang_config = SGLangConfig() if server_args: + sglang_config = copy.deepcopy(sglang_config) for k, v in server_args.items(): if hasattr(sglang_config, k): setattr(sglang_config, k, v) @@ -386,17 +396,19 @@ def _build_launch_cmd( elif inf_backend == "vllm": from areal.api.cli_args import vLLMConfig - vllm_config = vLLMConfig(model=cfg.model_path or cfg.tokenizer_path) - for k, v in (server_args or {}).items(): - if hasattr(vllm_config, k): - setattr(vllm_config, k, v) - else: - logger.warning( - "vLLMConfig has no attribute %r, ignoring " - "server_args entry (value=%r)", - k, - v, - ) + vllm_config = vLLMConfig() + if server_args: + vllm_config = copy.deepcopy(vllm_config) + for k, v in server_args.items(): + if hasattr(vllm_config, k): + setattr(vllm_config, k, v) + else: + logger.warning( + "vLLMConfig has no attribute %r, ignoring " + "server_args entry (value=%r)", + k, + v, + ) def _build_launch_cmd( host: str | None, @@ -583,16 +595,16 @@ def _build_launch_cmd( "--callback-server-addr", f"http://{self.callback_addr}", "--tool-call-parser", - cfg.tool_call_parser, + openai_cfg.tool_call_parser, "--reasoning-parser", - cfg.reasoning_parser, + openai_cfg.reasoning_parser, "--chat-template-type", - cfg.chat_template_type, + openai_cfg.chat_template_type, ] - if cfg.engine_max_tokens is not None: + if openai_cfg.engine_max_tokens is not None: data_proxy_base_cmd += [ "--engine-max-tokens", - str(cfg.engine_max_tokens), + str(openai_cfg.engine_max_tokens), ] for group_idx in range(dp_size): @@ -951,9 +963,7 @@ def get_version(self) -> int: def get_capacity(self) -> int: if self.staleness_manager is None: - raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") return self.staleness_manager.get_capacity() # -- Submit / Wait / Batch --------------------------------------------- @@ -1044,9 +1054,7 @@ def rollout_batch( A list of trajectory dicts (one per completed rollout). """ if not self._gateway_addr: - raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") if data is None: if batch_size is None: raise ValueError( @@ -1115,9 +1123,7 @@ def prepare_batch( A list of trajectory dicts (matching ``RolloutController`` API). """ if not self._gateway_addr: - raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") if dataloader is None: if batch_size is None: raise ValueError( @@ -1321,9 +1327,7 @@ def staleness_manager(self): @property def workflow_executor(self): if self._workflow_executor is None: - raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") return self._workflow_executor @property @@ -1353,8 +1357,8 @@ def _wrap_agent(self, agent: Any): "Gateway address is unavailable; initialize the controller first" ) - openai_cfg = self.config - admin_api_key = openai_cfg.admin_api_key + openai_cfg = self._openai_config + admin_api_key = self.config.admin_api_key turn_discount = openai_cfg.turn_discount export_style = openai_cfg.export_style @@ -1440,13 +1444,13 @@ def _resolve_workflow( # (c) Reject RolloutWorkflow classes and instances if isinstance(agent, type) and issubclass(agent, RolloutWorkflow): raise TypeError( - "GatewayInferenceController only accepts agent classes or instances with a " + "RolloutControllerV2 only accepts agent classes or instances with a " "run() method or None for online mode; direct RolloutWorkflow " "classes are not supported" ) if isinstance(agent, RolloutWorkflow): raise TypeError( - "GatewayInferenceController only accepts agent classes or instances with a " + "RolloutControllerV2 only accepts agent classes or instances with a " "run() method or None for online mode; direct RolloutWorkflow " "instances are not supported" ) @@ -1490,6 +1494,14 @@ def _resolve_should_accept_fn( return cast(Callable[[dict[str, Any]], bool], func) raise TypeError(f"Invalid should_accept_fn type: {type(should_accept_fn)}") + @property + def _openai_config(self): + from areal.api.cli_args import OpenAIProxyConfig + + return self.config.openai or OpenAIProxyConfig( + admin_api_key=self.config.admin_api_key + ) + # -- Internal HTTP helpers --------------------------------------------- def _fork_on_guard( diff --git a/areal/experimental/inference_service/controller/workflow.py b/areal/experimental/inference_service/controller/workflow.py index 95f0571770..3bbfeb8b97 100644 --- a/areal/experimental/inference_service/controller/workflow.py +++ b/areal/experimental/inference_service/controller/workflow.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from areal.api.engine_api import InferenceEngine from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.openai.types import InteractionWithTokenLogpReward @@ -29,7 +29,7 @@ class InferenceServiceWorkflow(RolloutWorkflow): def __init__( self, - controller: GatewayInferenceController, + controller: RolloutControllerV2, agent: Any | None = None, gateway_addr: str = "", admin_api_key: str = "areal-admin-key", diff --git a/areal/infra/data_service/controller/controller.py b/areal/infra/data_service/controller/controller.py index 5bfd2ae7ab..67188b8150 100644 --- a/areal/infra/data_service/controller/controller.py +++ b/areal/infra/data_service/controller/controller.py @@ -5,7 +5,7 @@ Manages the full lifecycle: create RPCGuard workers → fork DataWorkers, Router, Gateway → register datasets → serve batches → shutdown. -Follows the same patterns as ``GatewayInferenceController``. +Follows the same patterns as ``RolloutControllerV2``. """ from __future__ import annotations @@ -33,7 +33,7 @@ class DataController: """Controller for the distributed data loading service. - API follows ``TrainController`` / ``GatewayInferenceController`` patterns: + API follows ``TrainController`` / ``RolloutControllerV2`` patterns: ``__init__(config, scheduler)`` then ``initialize(role, ...)``. """ diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 978a6c18de..61fd4f994a 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -6,7 +6,7 @@ import os from collections.abc import Callable from copy import deepcopy -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch.distributed as dist from torchdata.stateful_dataloader import StatefulDataLoader @@ -35,6 +35,9 @@ vLLMConfig, ) from areal.engine import RemoteSGLangEngine, RemotevLLMEngine +from areal.experimental.inference_service.controller.controller import ( + RolloutControllerV2, +) from areal.infra import ( LocalScheduler, RayScheduler, @@ -1008,7 +1011,12 @@ def _init_rollout( return engine # Single-controller mode - no engine instantiation needed - controller = engine_cls.as_controller(config, self.scheduler) + if config._version == "v2": + controller = RolloutControllerV2( + config=config, scheduler=cast(Scheduler, self.scheduler) + ) + else: + controller = engine_cls.as_controller(config, self.scheduler) init_kwargs = dict( role="rollout", server_args=server_args, diff --git a/areal/utils/logging.py b/areal/utils/logging.py index 3d4f83bcd3..8b1130af8c 100644 --- a/areal/utils/logging.py +++ b/areal/utils/logging.py @@ -110,9 +110,9 @@ "AgentRouter": "light_purple", "AgentWorker": "light_purple", "AgentDataProxy": "light_purple", - "AgentServiceController": "light_purple", + "AgentController": "light_purple", # Inference service - white (orchestration) - "GatewayInferenceController": "white", + "RolloutControllerV2": "white", "InferenceDataProxy": "white", "InferenceInfBridge": "white", "InferenceRouter": "white", diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 8ea5263c78..abec4059ab 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -71,6 +71,7 @@ For detailed examples, see the experiment configurations in the `examples/` dire ### Others +- [Agent Configuration](section-agent) - [ArchonEngine Configuration](section-archon-engine) - [ArchonFP8 Configuration](section-archon-fp8) - [DPO Configuration](section-dpo) @@ -527,30 +528,40 @@ Controls text generation behavior for rollout. Configuration for inference servers, including offpolicyness control. -| Parameter | Type | Default | Description | -| ------------------------- | ---------------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `experiment_name` | string \| None | `None` | - | -| `trial_name` | string \| None | `None` | - | -| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | -| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | -| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | -| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | -| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | -| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | -| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | -| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | -| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | -| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | -| `request_timeout` | float | `3600` | Timeout for HTTP requests. | -| `request_retries` | integer | `3` | Number of retries for failed requests. | -| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | -| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | -| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | -| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| Parameter | Type | Default | Description | +| --------------------------- | ---------------------------------------------------- | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `experiment_name` | string \| None | `None` | - | +| `trial_name` | string \| None | `None` | - | +| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | +| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | +| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | +| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | +| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | +| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | +| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | +| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | +| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | +| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | +| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | +| `request_timeout` | float | `3600` | Timeout for HTTP requests. | +| `request_retries` | integer | `3` | Number of retries for failed requests. | +| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | +| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | +| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| `_version` | string | `"v1"` | Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2. **Choices:** `v1`, `v2` | +| `model` | string | `"default"` | Model name exposed through the inference-service gateway. | +| `routing_strategy` | string | `"round_robin"` | Routing strategy for the inference-service router. | +| `poll_interval` | float | `5.0` | Health-poll interval in seconds for the inference-service router. | +| `set_reward_finish_timeout` | float | `0.0` | Timeout in seconds to wait for additional reward updates before finalizing a session. | +| `log_level` | string | `"info"` | Log level for inference-service micro-services. | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by the inference-service gateway, router, and data proxies. | +| `api_url` | string \| None | `None` | External OpenAI-compatible base URL for inference-service external model mode. | +| `provider_api_key` | string \| None | `None` | API key for the external OpenAI-compatible provider. | +| `n_gpus_per_node` | integer \| None | `None` | GPUs per physical node for multinode inference-service launch. | (section-sg-lang)= @@ -856,6 +867,23 @@ Configuration for Weights & Biases experiment tracking. | `config` | `dict` \| None | `None` | - | | `id_suffix` | string \| None | `"train"` | - | +(section-agent)= + +## Agent Configuration + +Configuration for the experimental agent service controller. + +| Parameter | Type | Default | Description | +| ---------------------- | ------- | --------------------- | ------------------------------------------------------------------------- | +| `agent_cls_path` | string | `""` | Fully-qualified import path for the AgentRunnable implementation. | +| `admin_api_key` | string | `"areal-agent-admin"` | Shared admin API key for agent-service inter-service auth. | +| `num_pairs` | integer | `1` | Number of Worker+DataProxy pairs to launch on initialize. | +| `setup_timeout` | float | `120.0` | Timeout in seconds waiting for each service to become healthy. | +| `health_poll_interval` | float | `5.0` | Seconds between pair health polls; 0 disables health monitoring. | +| `drain_timeout` | float | `30.0` | Seconds to wait for active sessions to drain before force-killing a pair. | +| `log_level` | string | `"info"` | Log level for spawned agent-service micro-services. | +| `env` | `dict` | **Required** | Extra environment variables passed to all forked child processes. | + (section-archon-engine)= ## ArchonEngine Configuration diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index b1ca97aa3e..ebff40ba84 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -69,6 +69,7 @@ python3 train.py --config path/to/config.yaml actor.lr=1e-4 seed=42 ### Others +- [Agent Configuration](section-agent) - [ArchonEngine Configuration](section-archon-engine) - [ArchonFP8 Configuration](section-archon-fp8) - [DPO Configuration](section-dpo) @@ -525,30 +526,40 @@ Controls text generation behavior for rollout. Configuration for inference servers, including offpolicyness control. -| Parameter | Type | Default | Description | -| ------------------------- | ---------------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `experiment_name` | string \| None | `None` | - | -| `trial_name` | string \| None | `None` | - | -| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | -| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | -| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | -| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | -| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | -| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | -| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | -| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | -| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | -| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | -| `request_timeout` | float | `3600` | Timeout for HTTP requests. | -| `request_retries` | integer | `3` | Number of retries for failed requests. | -| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | -| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | -| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | -| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| Parameter | Type | Default | Description | +| --------------------------- | ---------------------------------------------------- | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `experiment_name` | string \| None | `None` | - | +| `trial_name` | string \| None | `None` | - | +| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | +| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | +| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | +| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | +| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | +| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | +| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | +| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | +| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | +| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | +| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | +| `request_timeout` | float | `3600` | Timeout for HTTP requests. | +| `request_retries` | integer | `3` | Number of retries for failed requests. | +| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | +| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | +| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| `_version` | string | `"v1"` | Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2. **Choices:** `v1`, `v2` | +| `model` | string | `"default"` | Model name exposed through the inference-service gateway. | +| `routing_strategy` | string | `"round_robin"` | Routing strategy for the inference-service router. | +| `poll_interval` | float | `5.0` | Health-poll interval in seconds for the inference-service router. | +| `set_reward_finish_timeout` | float | `0.0` | Timeout in seconds to wait for additional reward updates before finalizing a session. | +| `log_level` | string | `"info"` | Log level for inference-service micro-services. | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by the inference-service gateway, router, and data proxies. | +| `api_url` | string \| None | `None` | External OpenAI-compatible base URL for inference-service external model mode. | +| `provider_api_key` | string \| None | `None` | API key for the external OpenAI-compatible provider. | +| `n_gpus_per_node` | integer \| None | `None` | GPUs per physical node for multinode inference-service launch. | (section-sg-lang)= @@ -854,6 +865,23 @@ Configuration for Weights & Biases experiment tracking. | `config` | `dict` \| None | `None` | - | | `id_suffix` | string \| None | `"train"` | - | +(section-agent)= + +## Agent Configuration + +Configuration for the experimental agent service controller. + +| Parameter | Type | Default | Description | +| ---------------------- | ------- | --------------------- | ------------------------------------------------------------------------- | +| `agent_cls_path` | string | `""` | Fully-qualified import path for the AgentRunnable implementation. | +| `admin_api_key` | string | `"areal-agent-admin"` | Shared admin API key for agent-service inter-service auth. | +| `num_pairs` | integer | `1` | Number of Worker+DataProxy pairs to launch on initialize. | +| `setup_timeout` | float | `120.0` | Timeout in seconds waiting for each service to become healthy. | +| `health_poll_interval` | float | `5.0` | Seconds between pair health polls; 0 disables health monitoring. | +| `drain_timeout` | float | `30.0` | Seconds to wait for active sessions to drain before force-killing a pair. | +| `log_level` | string | `"info"` | Log level for spawned agent-service micro-services. | +| `env` | `dict` | **Required** | Extra environment variables passed to all forked child processes. | + (section-archon-engine)= ## ArchonEngine Configuration diff --git a/examples/agent_service/README.md b/examples/agent_service/README.md deleted file mode 100644 index 563064b4e5..0000000000 --- a/examples/agent_service/README.md +++ /dev/null @@ -1,114 +0,0 @@ -# Agent Service — Claude Agent SDK - -## Overview - -This example demonstrates AReaL's Agent Service running the **Claude Agent SDK** -(`claude-agent-sdk`) as a scalable HTTP micro-service. It turns Claude's autonomous -agent capabilities — multi-turn conversations, tool use, file editing, web search — into -a production-deployable service with session management, load balancing, and dynamic -scaling. - -**Why this matters**: Projects like -[claude-agent-acp](https://github.com/agentclientprotocol/claude-agent-acp) expose -Claude Agent SDK via custom protocols (ACP) for editor integration. AReaL takes a -different approach — it wraps Claude Agent SDK into standard HTTP micro-services with -session-affine routing, so you can **scale, orchestrate, and train** Claude agents using -AReaL's RL infrastructure. - -``` -Client → Gateway (HTTP) → Router → DataProxy (session state) → Worker (ClaudeSDKClient) -``` - -## Prerequisites - -```bash -uv pip install claude-agent-sdk -export ANTHROPIC_API_KEY=sk-... -``` - -## Quick Start - -```bash -python examples/agent_service/run_agent_service.py -``` - -The script creates a `LocalScheduler`, launches Guard workers, then forks Router → -Worker+DataProxy → Gateway. An interactive prompt lets you chat with the Claude agent. - -### Options - -```bash -python examples/agent_service/run_agent_service.py --num-pairs 4 -``` - -### Send requests directly - -```bash -curl -X POST http://localhost:/v1/responses \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer areal-agent-admin" \ - -d '{ - "input": [{"type": "message", "content": "Explain RLHF in simple terms"}], - "model": "claude-agent", - "user": "my-session" - }' -``` - -## Configuration - -Claude Agent SDK settings are controlled via environment variables: - -| Variable | Default | Description | -| ---------------------- | ------------------- | --------------------------- | -| `ANTHROPIC_API_KEY` | (required) | Anthropic API key | -| `CLAUDE_MODEL` | `claude-sonnet-4-6` | Model to use | -| `CLAUDE_SYSTEM_PROMPT` | (none) | Optional system prompt | -| `CLAUDE_MAX_TURNS` | `20` | Max agentic turns per query | - -## Architecture - -The Worker maintains a **session-persistent `ClaudeSDKClient`** per session key. Unlike -stateless wrappers, the SDK's internal session retains the full conversation transcript -— no need to re-send history on each turn. - -``` -Turn 1: Client → Gateway → Router → DataProxy → Worker - Worker: creates ClaudeSDKClient for session "abc" - Claude Agent SDK runs autonomously (tool calls, file ops, etc.) - Response streams back through the chain - -Turn 2: Client → Gateway → Router (same DataProxy) → DataProxy → Worker - Worker: reuses ClaudeSDKClient for session "abc" - SDK remembers full context from Turn 1 -``` - -## Programmatic Usage - -```python -from areal.experimental.agent_service.controller import ( - AgentServiceController, - AgentServiceControllerConfig, -) -from areal.infra.scheduler.local import LocalScheduler - -scheduler = LocalScheduler(experiment_name="demo", trial_name="run0", gpu_devices=[]) -ctrl = AgentServiceController( - config=AgentServiceControllerConfig( - agent_cls_path="examples.agent_service.agent.ClaudeAgent", - num_pairs=2, - ), - scheduler=scheduler, -) -ctrl.initialize() -# ctrl.gateway_addr → "http://10.0.0.1:9005" -# ctrl.scale_up(2) → add 2 more pairs -# ctrl.scale_down(1) → remove 1 pair (with graceful drain) -ctrl.destroy() -``` - -## Files - -| File | Description | -| ---------------------- | ----------------------------------------------------------- | -| `agent.py` | `ClaudeAgent` — session-persistent Claude Agent SDK wrapper | -| `run_agent_service.py` | Controller-based launcher + interactive conversation | diff --git a/examples/experimental/__init__.py b/examples/experimental/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/examples/experimental/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/examples/experimental/agent_service/README.md b/examples/experimental/agent_service/README.md new file mode 100644 index 0000000000..74787b53f3 --- /dev/null +++ b/examples/experimental/agent_service/README.md @@ -0,0 +1,91 @@ +# Agent Service Examples + +## Overview + +This directory contains experimental examples built on top of AReaL's agent service. The +examples are grouped by scenario: + +- `claude/` — a standalone Claude Agent SDK service demo +- `tau2/` — a tau2 customer-service rollout example that combines the agent service with + the experimental inference service + +The agent service exposes complete agent sessions through Router, DataProxy, Worker, and +Gateway microservices, and can be paired with the experimental inference service for RL +data collection. + +## Example 1: Claude Agent SDK Service + +This is the Claude Agent SDK example under the new `claude/` subdirectory. + +### Prerequisites + +```bash +uv pip install claude-agent-sdk +export ANTHROPIC_API_KEY=sk-... +``` + +### Run + +```bash +python examples/experimental/agent_service/claude/run_agent_service.py +python examples/experimental/agent_service/claude/run_agent_service.py --num-pairs 4 +``` + +The script creates a `LocalScheduler`, launches Guard workers, then forks Router, +Worker+DataProxy pairs, and Gateway. An interactive prompt lets you chat with the Claude +agent through `POST /v1/responses`. + +Files: + +- `claude/agent.py` — Claude Agent SDK worker implementation +- `claude/run_agent_service.py` — interactive launcher for the Claude example + +## Example 2: Tau2 Agent Service Rollout + +This example runs the tau2 customer-service agent inside the experimental agent service +while the experimental inference service collects RL trajectories. Unlike the reference +inference-service example, this script initializes `RolloutControllerV2` but does not +use `rollout_batch()`. It directly runs the tau2 workflow and returns exported +trajectories from the inference service. + +### Additional Prerequisites + +```bash +pip install git+https://github.com/dhh1995/tau2-bench.git@dhh/async-and-custom-completion +pip install pydantic-ai +export TAU2_DATA_DIR=/path/to/tau2-bench/data +``` + +If `econfig.solo_mode=false`, also start a user simulator model and set +`econfig.user_llm_base_url` in `tau2/config.yaml`. + +### Run + +```bash +python examples/experimental/agent_service/tau2/run_rollout.py \ + --config examples/experimental/agent_service/tau2/config.yaml \ + cluster.fileroot= \ + cluster.name_resolve.nfs_record_root= +``` + +### What it does + +1. Starts the experimental inference service with `RolloutControllerV2`. +1. Starts the experimental agent service with `AgentController`. +1. For each tau2 task, the workflow: + - calls `AgentController.start_session()` (which grants capacity and starts the RL + session), + - drives the tau2 conversation through `AgentController.step()`, + - calls `AgentController.set_reward()`, + - calls `AgentController.export_trajectory()` and returns the exported interactions. + +### Files + +| File | Description | +| ----------------------------- | ------------------------------------------------- | +| `claude/agent.py` | Claude Agent SDK example agent | +| `claude/run_agent_service.py` | Interactive launcher for the Claude agent service | +| `tau2/agent.py` | Tau2 agent-service worker agent | +| `tau2/workflow.py` | Tau2 rollout workflow using async controller APIs | +| `tau2/run_rollout.py` | Direct rollout driver for the tau2 workflow | +| `tau2/config.yaml` | Example config for the tau2 rollout driver | diff --git a/examples/experimental/agent_service/__init__.py b/examples/experimental/agent_service/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/examples/experimental/agent_service/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/examples/experimental/agent_service/claude/__init__.py b/examples/experimental/agent_service/claude/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/examples/experimental/agent_service/claude/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/examples/agent_service/agent.py b/examples/experimental/agent_service/claude/agent.py similarity index 81% rename from examples/agent_service/agent.py rename to examples/experimental/agent_service/claude/agent.py index c05f3bebe5..2b3a3d01a3 100644 --- a/examples/agent_service/agent.py +++ b/examples/experimental/agent_service/claude/agent.py @@ -1,19 +1,9 @@ """Claude Agent for AReaL Agent Service. Implements :class:`AgentRunnable` using the Claude Agent SDK -(``claude-agent-sdk``). Each Worker instance holds a pool of +(``claude-agent-sdk``). Each Worker instance holds a pool of :class:`ClaudeSDKClient` sessions keyed by ``session_key``, so multi-turn conversations preserve full context without re-sending history. - -Requires:: - - pip install claude-agent-sdk - -Environment variables: - ANTHROPIC_API_KEY — Anthropic API key (required) - CLAUDE_MODEL — model name (default: claude-sonnet-4-6) - CLAUDE_SYSTEM_PROMPT — optional system prompt override - CLAUDE_MAX_TURNS — max agentic turns per query (default: 20) """ from __future__ import annotations @@ -45,20 +35,14 @@ class ClaudeAgent: - """AgentRunnable backed by the Claude Agent SDK. - - Maintains a ``ClaudeSDKClient`` per session for true multi-turn - continuity — the SDK's internal session keeps the full transcript, - so ``request.history`` is only used for the very first turn of a - new session (to seed context if provided by the caller). - """ + """AgentRunnable backed by the Claude Agent SDK.""" def __init__(self, **kwargs: Any) -> None: + del kwargs self._model = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-6") self._system_prompt = os.environ.get("CLAUDE_SYSTEM_PROMPT", "") self._max_turns = int(os.environ.get("CLAUDE_MAX_TURNS", "20")) self._permission_mode: PermissionMode = _DEFAULT_PERMISSION_MODE - self._sessions: dict[str, ClaudeSDKClient] = {} logger.info( @@ -94,8 +78,7 @@ async def close_session(self, session_key: str) -> None: logger.warning("Error closing session %s", session_key, exc_info=True) async def close_all_sessions(self) -> None: - keys = list(self._sessions.keys()) - for key in keys: + for key in list(self._sessions): await self.close_session(key) async def run( @@ -111,7 +94,6 @@ async def run( text_parts: list[str] = [] tool_calls: list[dict[str, Any]] = [] - async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: @@ -129,9 +111,8 @@ async def run( elif isinstance(msg, ResultMessage): break - summary = "".join(text_parts) return AgentResponse( - summary=summary[:200], + summary="".join(text_parts)[:200], metadata={"tool_calls": tool_calls}, ) except Exception: diff --git a/examples/agent_service/run_agent_service.py b/examples/experimental/agent_service/claude/run_agent_service.py similarity index 74% rename from examples/agent_service/run_agent_service.py rename to examples/experimental/agent_service/claude/run_agent_service.py index e96f83f501..77aa7db577 100644 --- a/examples/agent_service/run_agent_service.py +++ b/examples/experimental/agent_service/claude/run_agent_service.py @@ -1,17 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -"""Launch the Agent Service with Claude Agent SDK. - -Usage:: - - python examples/agent_service/run_agent_service.py - python examples/agent_service/run_agent_service.py --num-pairs 2 - -Requires:: - - uv pip install claude-agent-sdk - export ANTHROPIC_API_KEY=sk-... -""" +"""Launch the Agent Service with Claude Agent SDK.""" from __future__ import annotations @@ -21,10 +10,8 @@ import httpx -from areal.experimental.agent_service.controller import ( - AgentServiceController, - AgentServiceControllerConfig, -) +from areal.api.cli_args import AgentConfig +from areal.experimental.agent_service.controller import AgentController async def _wait_healthy(url: str, timeout: float = 60.0) -> None: @@ -51,7 +38,7 @@ async def interactive_loop(gateway_addr: str, admin_key: str) -> None: user_input = input("You: ") except (EOFError, KeyboardInterrupt): break - if user_input.strip().lower() in ("quit", "exit", "q"): + if user_input.strip().lower() in {"quit", "exit", "q"}: break if not user_input.strip(): continue @@ -82,17 +69,8 @@ async def interactive_loop(gateway_addr: str, admin_key: str) -> None: def main() -> None: parser = argparse.ArgumentParser(description="Agent Service — Claude Agent SDK") - parser.add_argument( - "--num-pairs", - type=int, - default=1, - help="Number of Worker+DataProxy pairs (default: 1)", - ) - parser.add_argument( - "--admin-api-key", - default="areal-agent-admin", - help="Admin API key for inter-service auth", - ) + parser.add_argument("--num-pairs", type=int, default=1) + parser.add_argument("--admin-api-key", default="areal-agent-admin") args = parser.parse_args() from areal.infra.scheduler.local import LocalScheduler @@ -103,12 +81,14 @@ def main() -> None: gpu_devices=[], ) - ctrl_config = AgentServiceControllerConfig( - agent_cls_path="examples.agent_service.agent.ClaudeAgent", - admin_api_key=args.admin_api_key, - num_pairs=args.num_pairs, + ctrl = AgentController( + config=AgentConfig( + agent_cls_path="examples.experimental.agent_service.claude.agent.ClaudeAgent", + admin_api_key=args.admin_api_key, + num_pairs=args.num_pairs, + ), + scheduler=scheduler, ) - ctrl = AgentServiceController(config=ctrl_config, scheduler=scheduler) try: print(f"Initializing with {args.num_pairs} pair(s) ...") @@ -119,7 +99,6 @@ def main() -> None: asyncio.run(_wait_healthy(f"{ctrl.gateway_addr}/health")) print("All services ready.\n") - asyncio.run(interactive_loop(ctrl.gateway_addr, admin_key=args.admin_api_key)) finally: print("\nShutting down ...") diff --git a/examples/experimental/agent_service/tau2/__init__.py b/examples/experimental/agent_service/tau2/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/examples/experimental/agent_service/tau2/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/examples/experimental/agent_service/tau2/agent.py b/examples/experimental/agent_service/tau2/agent.py new file mode 100644 index 0000000000..f5c37ab440 --- /dev/null +++ b/examples/experimental/agent_service/tau2/agent.py @@ -0,0 +1,225 @@ +"""Tau2 agent for the experimental agent service.""" + +from __future__ import annotations + +import inspect +import json +import os +from typing import Any + +from areal.experimental.agent_service.types import ( + AgentRequest, + AgentResponse, + EventEmitter, +) +from areal.utils import logging + +logger = logging.getLogger("Tau2Agent") + + +def _make_pydantic_tool(tau2_tool: Any): + fn = tau2_tool._func # noqa: SLF001 + name = tau2_tool.name + schema = getattr(tau2_tool, "openai_schema", {}) or {} + doc = schema.get("function", {}).get("description", name) + + async def _wrapper(**kwargs: Any) -> str: + try: + result = fn(**kwargs) + except Exception as exc: + result = f"Tool error: {exc}" + if not isinstance(result, str): + result = json.dumps(result, default=str) + return result + + _wrapper.__name__ = name + _wrapper.__qualname__ = name + _wrapper.__doc__ = doc + sig = inspect.signature(fn) + _wrapper.__signature__ = inspect.Signature( + [ + inspect.Parameter( + pname, + kind=inspect.Parameter.KEYWORD_ONLY, + default=param.default, + annotation=param.annotation, + ) + for pname, param in sig.parameters.items() + ] + ) + if hasattr(fn, "__annotations__"): + _wrapper.__annotations__ = { + k: v for k, v in fn.__annotations__.items() if k != "return" + } + return _wrapper + + +def _think_tool_fn(thoughts: str) -> str: + del thoughts + return "Your thoughts are recorded. Please continue your work." + + +class Tau2Agent: + """AgentRunnable that wraps a PydanticAI agent with tau2 tools.""" + + def __init__(self, config: dict[str, Any] | None = None, **kwargs: Any) -> None: + del kwargs + try: + from pydantic_ai import Agent + from pydantic_ai.models.openai import OpenAIChatModel + from pydantic_ai.providers.openai import OpenAIProvider + from tau2.environment.tool import Tool as Tau2Tool + from tau2.registry import registry + except ImportError as exc: + raise ImportError( + "Tau2 agent service example requires 'pydantic-ai' and 'tau2-bench'" + ) from exc + + config = config or {} + tau2_cfg = config.get("tau2", {}) + agent_llm_cfg = config.get("agent_llm", {}) + + self._domain = tau2_cfg.get("domain") or os.environ.get( + "TAU2_DOMAIN", "airline" + ) + add_thinking = tau2_cfg.get("add_thinking_tool", False) + + data_dir = tau2_cfg.get("data_dir") or os.environ.get("TAU2_DATA_DIR") + if data_dir: + os.environ["TAU2_DATA_DIR"] = data_dir + + env_constructor = registry.get_env_constructor(self._domain) + env = env_constructor(solo_mode=False) + tau2_tools: list[Tau2Tool] = env.get_tools() + if add_thinking: + tau2_tools.append(Tau2Tool(_think_tool_fn)) + + tools = [_make_pydantic_tool(t) for t in tau2_tools] + system_prompt = env.get_policy() + + model_name = agent_llm_cfg.get("model", "openai:default") + base_url = agent_llm_cfg.get("base_url") + api_key = agent_llm_cfg.get("api_key", "unused") + + if base_url: + model: Any = OpenAIChatModel( + model_name.replace("openai:", ""), + provider=OpenAIProvider(base_url=base_url, api_key=api_key), + ) + else: + model = model_name + + self._openai_chat_model = OpenAIChatModel + self._openai_provider = OpenAIProvider + self._agent = Agent(model, system_prompt=system_prompt, tools=tools) + logger.info( + "Tau2Agent initialized (domain=%s, tools=%d, model=%s)", + self._domain, + len(tools), + model_name, + ) + + def _resolve_model(self, metadata: dict[str, Any]) -> Any: + base_url = metadata.get("inference_base_url") + if not base_url: + return self._agent.model + model_name = metadata.get("inference_model", "default") + api_key = metadata.get("inference_api_key", "unused") + return self._openai_chat_model( + model_name, + provider=self._openai_provider(base_url=base_url, api_key=api_key), + ) + + async def run( + self, + request: AgentRequest, + *, + emitter: EventEmitter, + ) -> AgentResponse: + from pydantic_ai.messages import ( + ModelRequest, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + ) + from pydantic_ai.messages import ModelResponse as PAModelResponse + + message_history: list[ModelRequest | PAModelResponse] = [] + for msg in request.history: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "user": + message_history.append( + ModelRequest(parts=[UserPromptPart(content=content or "")]) + ) + elif role == "assistant": + tool_calls = msg.get("tool_calls") + if tool_calls: + parts = [] + for tc in tool_calls: + fn = tc.get("function", tc) + parts.append( + ToolCallPart( + tool_name=fn.get("name", ""), + args=fn.get("arguments", ""), + tool_call_id=tc.get("id", ""), + ) + ) + message_history.append(PAModelResponse(parts=parts)) + elif content: + message_history.append( + PAModelResponse(parts=[TextPart(content=content)]) + ) + elif role == "tool": + tool_call_id = msg.get("tool_call_id", "") + message_history.append( + ModelRequest( + parts=[ + ToolReturnPart( + tool_name=tool_call_id, + content=content or "", + tool_call_id=tool_call_id, + ) + ] + ) + ) + + try: + result = await self._agent.run( + request.message, + message_history=message_history, + model=self._resolve_model(request.metadata), + ) + except Exception as exc: + logger.error("Tau2Agent turn failed: %s", exc) + await emitter.emit_delta(f"Agent error: {exc}") + return AgentResponse( + summary=f"Agent error: {exc}", metadata={"tool_calls": []} + ) + + final_text = str(result.output) if result.output else "" + tool_calls: list[dict[str, Any]] = [] + for msg in result.new_messages(): + if not hasattr(msg, "parts"): + continue + for part in msg.parts: + kind = getattr(part, "part_kind", "") + if kind == "tool-call": + name = getattr(part, "tool_name", "") + args = getattr(part, "args", "") + if isinstance(args, dict): + args = json.dumps(args) + await emitter.emit_tool_call(name=name, args=str(args)) + tool_calls.append({"name": name, "arguments": args}) + elif kind == "tool-return": + name = getattr(part, "tool_name", "") + content = str(getattr(part, "content", "")) + await emitter.emit_tool_result(name=name, result=content) + + if final_text: + await emitter.emit_delta(final_text) + + return AgentResponse( + summary=final_text[:200], metadata={"tool_calls": tool_calls} + ) diff --git a/examples/experimental/agent_service/tau2/config.yaml b/examples/experimental/agent_service/tau2/config.yaml new file mode 100644 index 0000000000..a3db01562d --- /dev/null +++ b/examples/experimental/agent_service/tau2/config.yaml @@ -0,0 +1,134 @@ +experiment_name: tau2-agent-service-rollout +trial_name: 1.7b-telecom + +seed: 1 +enable_offload: false +total_train_epochs: 1 +total_train_steps: null +tokenizer_path: ${model_path} + +model_path: Qwen/Qwen3-1.7B + +cluster: + n_nodes: 1 + n_gpus_per_node: 2 + fileroot: /path/to/experiments + name_resolve: + type: nfs + nfs_record_root: /path/to/name_resolve + +scheduler: + type: local + +gconfig: + n_samples: 1 + min_new_tokens: 0 + max_new_tokens: 8192 + max_tokens: 16384 + greedy: false + temperature: 1.0 + +rollout: + _version: v2 + backend: "sglang:d2" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 32 + queue_size: null + consumer_batch_size: 8 + max_head_offpolicyness: 1000000000 + enable_rollout_tracing: true + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: false + admin_api_key: rollout-admin + openai: + mode: inline + tool_call_parser: qwen25 + reasoning_parser: qwen3 + engine_max_tokens: ${gconfig.max_tokens} + export_style: individual + turn_discount: 1.0 + admin_api_key: rollout-admin + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + cpu: 2 + mem: 16 + cmd: python3 -m areal.experimental.inference_service.guard + env_vars: + AREAL_PROXY_WARN_ONCE: "1" + +agent_service: + agent_cls_path: examples.experimental.agent_service.tau2.agent.Tau2Agent + admin_api_key: areal-agent-admin + num_pairs: 1 + setup_timeout: 120.0 + health_poll_interval: 5.0 + drain_timeout: 30.0 + log_level: info + env: {} + +econfig: + domain: telecom + max_steps: 50 + add_thinking_tool: false + solo_mode: false + user_llm_base_url: http://localhost:8000/v1/ + user_llm: openai/self-hosted-Qwen2.5-72B + user_llm_args: + temperature: 0.0 + max_completion_tokens: 512 + turn_discount: 1.0 + invalid_format_penalty: 0.1 + +sglang: + model_path: ${model_path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: bfloat16 + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +train_dataset: + batch_size: 8 + shuffle: true + pin_memory: true + num_workers: 4 + path: tau2/train + type: rl + max_length: 1024 + +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled diff --git a/examples/experimental/agent_service/tau2/run_rollout.py b/examples/experimental/agent_service/tau2/run_rollout.py new file mode 100644 index 0000000000..919569438b --- /dev/null +++ b/examples/experimental/agent_service/tau2/run_rollout.py @@ -0,0 +1,208 @@ +"""Direct rollout driver for the tau2 agent-service workflow.""" + +from __future__ import annotations + +import asyncio +import os +import sys +import warnings +from copy import deepcopy +from dataclasses import asdict, dataclass, field +from typing import Any + +from datasets import Dataset + +from areal.api.alloc_mode import ModelAllocation +from areal.api.cli_args import ( + AgentConfig, + BaseExperimentConfig, + GenerationHyperparameters, + InferenceEngineConfig, + SGLangConfig, + TrainDatasetConfig, + load_expr_config, +) +from areal.experimental.agent_service.controller import AgentController +from areal.experimental.inference_service.controller.controller import ( + RolloutControllerV2, +) +from areal.utils import logging + +logger = logging.getLogger("Tau2AgentServiceRollout") + + +def get_tau2_dataset(domain: str, type: str = "rl", split: str = "train") -> Dataset: + from tau2.registry import registry + + assert type == "rl", "Only RL dataset is supported for now" + splits_loader_fn = registry.get_task_splits_loader(domain) + if splits_loader_fn is None: + raise ValueError(f"No task splits loader found for domain {domain}") + splits = splits_loader_fn() + if split not in splits: + raise ValueError( + f"Split {split} not found for domain {domain}, available splits: {list(splits.keys())}" + ) + task_ids = splits[split] + dataset_items = [{"task_id": task_id, "split": split} for task_id in task_ids] + if len(dataset_items) < 128: + original_items = dataset_items.copy() + while len(dataset_items) < 128: + dataset_items.extend(original_items) + return Dataset.from_list(dataset_items) + + +@dataclass +class Tau2AgentServiceRolloutConfig(BaseExperimentConfig): + gconfig: GenerationHyperparameters = field( + default_factory=GenerationHyperparameters + ) + rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig) + model_path: str = "" + econfig: dict[str, Any] = field(default_factory=dict) + agent_service: AgentConfig = field( + default_factory=lambda: AgentConfig( + agent_cls_path="examples.experimental.agent_service.tau2.agent.Tau2Agent" + ) + ) + sglang: SGLangConfig = field(default_factory=SGLangConfig) + train_dataset: TrainDatasetConfig = field(default_factory=TrainDatasetConfig) + + +async def _run_rollouts( + workflow: Any, + controller: RolloutControllerV2, + dataloader: Any, + *, + max_batches: int | None = None, +) -> None: + batch_count = 0 + for batch_idx, batch in enumerate(dataloader): + if max_batches is not None and batch_count >= max_batches: + break + + keys = list(batch.keys()) + batch_size = len(batch[keys[0]]) + data_rows = [{k: batch[k][i] for k in keys} for i in range(batch_size)] + + results = await asyncio.gather( + *(workflow.arun_episode(controller, row) for row in data_rows) + ) + + rewards: list[float] = [] + trajectories = 0 + for result in results: + if not result: + continue + trajectories += 1 + last_id = next(reversed(result)) + reward = result[last_id].reward + if reward is not None: + rewards.append(float(reward)) + + avg_reward = sum(rewards) / len(rewards) if rewards else 0.0 + logger.info( + "Batch %d: n_trajs=%d, rewards=%s, avg_reward=%.4f", + batch_idx, + trajectories, + rewards, + avg_reward, + ) + batch_count += 1 + + logger.info("Rollout complete (%d batches)", batch_count) + + +def main(argv: list[str]) -> None: + warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") + + config, _ = load_expr_config(argv, Tau2AgentServiceRolloutConfig) + rollout_cfg = deepcopy(config.rollout) + rollout_cfg.model = config.model_path + + from examples.experimental.agent_service.tau2.workflow import ( + Tau2AgentServiceWorkflow, + ) + from examples.tau2.utils import Tau2EnvConfig + + econfig = ( + Tau2EnvConfig(**config.econfig) + if isinstance(config.econfig, dict) + else config.econfig + ) + train_dataset = get_tau2_dataset( + domain=econfig.domain, + type=config.train_dataset.type, + split=config.train_dataset.path.split("/")[-1], + ) + + from torch.utils.data import DataLoader + + dataloader = DataLoader( + train_dataset, + batch_size=config.train_dataset.batch_size, + shuffle=config.train_dataset.shuffle, + num_workers=0, + ) + + from areal.infra.scheduler.local import LocalScheduler + from areal.infra.scheduler.slurm import SlurmScheduler + + if config.scheduler.type == "local": + scheduler = LocalScheduler(exp_config=config) + elif config.scheduler.type == "slurm": + scheduler = SlurmScheduler(exp_config=config) + else: + raise NotImplementedError(f"Unknown scheduler type: {config.scheduler.type}") + + rollout_alloc = ModelAllocation.from_str(config.rollout.backend, name="rollout") + if rollout_alloc.backend == "sglang": + server_args = asdict(config.sglang) + elif rollout_alloc.backend == "vllm": + server_args = asdict(config.vllm) + else: + raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") + + rollout_controller = RolloutControllerV2(config=rollout_cfg, scheduler=scheduler) + rollout_controller.initialize(role="rollout", server_args=server_args) + + agent_controller = AgentController(config=config.agent_service, scheduler=scheduler) + agent_controller.initialize() + + workflow = Tau2AgentServiceWorkflow( + agent_controller=agent_controller, + inference_gateway_addr=rollout_controller.proxy_gateway_addr, + inference_admin_api_key=rollout_cfg.admin_api_key, + inference_model=config.model_path, + econfig=asdict(econfig), + gen_args={ + "temperature": config.gconfig.temperature, + "max_completion_tokens": config.gconfig.max_new_tokens, + }, + timeout=600.0, + discount=rollout_cfg.openai.turn_discount if rollout_cfg.openai else 1.0, + export_style=( + rollout_cfg.openai.export_style if rollout_cfg.openai else "individual" + ), + ) + + max_batches_env = os.environ.get("AREAL_MAX_BATCHES") + max_batches = int(max_batches_env) if max_batches_env is not None else None + + try: + asyncio.run( + _run_rollouts( + workflow, + rollout_controller, + dataloader, + max_batches=max_batches, + ) + ) + finally: + agent_controller.destroy() + rollout_controller.destroy() + scheduler.delete_workers(None) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/experimental/agent_service/tau2/workflow.py b/examples/experimental/agent_service/tau2/workflow.py new file mode 100644 index 0000000000..990babb9ea --- /dev/null +++ b/examples/experimental/agent_service/tau2/workflow.py @@ -0,0 +1,251 @@ +"""Tau2 workflow using the experimental agent service.""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any + +from openai import AsyncOpenAI + +from areal.api.workflow_api import RolloutWorkflow +from areal.infra import workflow_context +from areal.utils import logging, stats_tracker + +if TYPE_CHECKING: + from areal.api.engine_api import InferenceEngine + from areal.experimental.agent_service.controller.controller import AgentController + from areal.experimental.openai.types import InteractionWithTokenLogpReward + +logger = logging.getLogger("Tau2AgentServiceWorkflow") + + +def _extract_response_text(response: dict[str, Any]) -> str: + parts: list[str] = [] + for item in response.get("output", []): + if item.get("type") == "message": + for block in item.get("content", []): + if block.get("type") == "output_text": + parts.append(block.get("text", "")) + return "\n".join(parts).strip() + + +def _extract_completion_text(completion: Any) -> str: + choice = completion.choices[0] + message = getattr(choice, "message", None) + content = getattr(message, "content", "") if message is not None else "" + if isinstance(content, str): + return content.strip() + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text" and item.get("text"): + parts.append(str(item["text"])) + else: + text = getattr(item, "text", None) + if text: + parts.append(str(text)) + return "\n".join(parts).strip() + return str(content).strip() + + +class Tau2AgentServiceWorkflow(RolloutWorkflow): + def __init__( + self, + agent_controller: AgentController, + inference_gateway_addr: str, + inference_admin_api_key: str, + inference_model: str = "", + econfig: dict[str, Any] | None = None, + gen_args: dict[str, Any] | None = None, + timeout: float = 600.0, + max_turns: int = 10, + discount: float = 1.0, + export_style: str = "individual", + ) -> None: + from examples.tau2.utils import Tau2EnvConfig + + self.agent_controller = agent_controller + self.inference_gateway_addr = inference_gateway_addr.rstrip("/") + self.inference_admin_api_key = inference_admin_api_key + self.inference_model = inference_model + self.econfig = ( + Tau2EnvConfig(**econfig) + if isinstance(econfig, dict) + else (econfig or Tau2EnvConfig()) + ) + self.gen_args = gen_args or {} + self.timeout = timeout + self.max_turns = max_turns + self.discount = discount + self.export_style = export_style + + async def _run_dialog( + self, + data: dict[str, Any], + agent_session_id: str, + ) -> float: + from tau2.data_model.message import AssistantMessage, UserMessage + from tau2.data_model.simulation import SimulationRun, TerminationReason + from tau2.evaluator.evaluator import EvaluationType, evaluate_simulation + + from examples.tau2.agent import _get_task + from examples.tau2.utils import Tau2EnvConfig + + econfig = self.econfig + if "econfig" in data: + econfig = Tau2EnvConfig(**data["econfig"]) + + task = _get_task( + domain=econfig.domain, + task_id=data["task_id"], + split=data.get("split", "train"), + ) + first_user_message = str(data.get("prompt") or task.user_scenario).strip() + if not first_user_message: + raise ValueError("data.prompt or task.user_scenario is required") + + user_client = None + if not econfig.solo_mode: + if not econfig.user_llm_base_url: + raise ValueError( + "econfig.user_llm_base_url is required when solo_mode is false" + ) + user_client = AsyncOpenAI( + base_url=econfig.user_llm_base_url, + api_key="dummy", + max_retries=3, + timeout=120.0, + ) + + tau2_messages: list[UserMessage | AssistantMessage] = [] + chat_history: list[dict[str, str]] = [ + {"role": "user", "content": first_user_message} + ] + next_user_message = first_user_message + + for turn_idx in range(self.max_turns): + response = await self.agent_controller.step( + next_user_message, agent_session_id + ) + agent_text = _extract_response_text(response) or "(no response)" + + tau2_messages.append( + UserMessage( + role="user", + content=next_user_message, + turn_idx=len(tau2_messages), + ) + ) + tau2_messages.append( + AssistantMessage( + role="assistant", + content=agent_text, + turn_idx=len(tau2_messages), + ) + ) + + if turn_idx + 1 >= self.max_turns or user_client is None: + break + + chat_history.append({"role": "assistant", "content": agent_text}) + completion = await user_client.chat.completions.create( + model=econfig.user_llm or "dummy", + messages=[ + { + "role": "system", + "content": ( + "You are simulating the tau2 user described below. " + "Respond with the user's next message only, in one turn, " + "based on the conversation so far.\n\n" + f"User scenario:\n{task.user_scenario}" + ), + }, + *chat_history, + ], + **(econfig.user_llm_args or {}), + ) + next_user_message = _extract_completion_text(completion) + if not next_user_message: + break + chat_history.append({"role": "user", "content": next_user_message}) + + simulation = SimulationRun( + id=f"agent-svc-{task.id}", + task_id=task.id, + messages=tau2_messages, + start_time="", + end_time="", + duration=0.0, + termination_reason=TerminationReason.USER_STOP, + ) + reward_info = evaluate_simulation( + simulation=simulation, + task=task, + evaluation_type=EvaluationType.ALL, + solo_mode=econfig.solo_mode, + domain=econfig.domain, + ) + return float(reward_info.reward) + + async def arun_episode( + self, + engine: InferenceEngine, + data: dict[str, Any], + ) -> dict[str, InteractionWithTokenLogpReward] | None: + del engine + task_id = str(data.get("task_id") or workflow_context.get().task_id) + session = await self.agent_controller.start_session( + task_id=task_id, + inference_gateway_addr=self.inference_gateway_addr, + inference_admin_api_key=self.inference_admin_api_key, + inference_model=self.inference_model, + ) + + trajectory_id: int | None = None + finished = False + try: + reward = await asyncio.wait_for( + self._run_dialog(data, session["session_id"]), + timeout=self.timeout, + ) + reward_result = await self.agent_controller.set_reward( + reward, + session["session_id"], + ) + raw_trajectory_id = reward_result.get("trajectory_id") + trajectory_id = ( + int(raw_trajectory_id) if raw_trajectory_id is not None else None + ) + finished = True + except Exception: + logger.warning( + "Tau2 agent-service task failed. This trajectory will be rejected." + ) + if not finished: + try: + await self.agent_controller.set_reward(0.0, session["session_id"]) + except Exception: + logger.warning( + "Failed to finish session %s after workflow failure", + session["session_id"], + ) + raise + + interactions = await self.agent_controller.export_trajectory( + session["session_id"], + trajectory_id=trajectory_id, + discount=self.discount, + style=self.export_style, + ) + if not interactions: + logger.warning( + "Session %s has no interactions, trajectory will be rejected.", + session["session_id"], + ) + return None + + last_id = next(reversed(interactions)) + last_reward = interactions[last_id].reward + stats_tracker.get(workflow_context.stat_scope()).scalar(reward=last_reward) + return interactions diff --git a/examples/experimental/inference_service/README.md b/examples/experimental/inference_service/README.md index 1b7c0af21f..a039974b7b 100644 --- a/examples/experimental/inference_service/README.md +++ b/examples/experimental/inference_service/README.md @@ -1,7 +1,7 @@ # AReaL Inference Service Examples This directory contains two examples that use the AReaL Inference Service -(`GatewayInferenceController`) — an experimental rollout backend that exposes an +(`RolloutControllerV2`) — an experimental rollout backend that exposes an OpenAI-compatible proxy gateway so any external agent runtime can submit chat requests and receive RL training data. diff --git a/examples/experimental/inference_service/human_in_the_loop_demo.py b/examples/experimental/inference_service/human_in_the_loop_demo.py index 96f08c00a8..a6e833512d 100644 --- a/examples/experimental/inference_service/human_in_the_loop_demo.py +++ b/examples/experimental/inference_service/human_in_the_loop_demo.py @@ -283,7 +283,8 @@ def cleanup(signum=None, frame=None): str(config_yaml), f"actor.path={args.actor_path}", f"rollout.backend={args.inference_backend}:d1", - f"rollout.openai.admin_api_key={args.admin_key}", + f"rollout.admin_api_key={args.admin_key}", + "rollout._version=v2", f"rollout.request_timeout={args.request_timeout}", ] if args.api_url: diff --git a/examples/experimental/inference_service/online_rollout.py b/examples/experimental/inference_service/online_rollout.py index 106f8e86d0..dd0bfd1bdb 100644 --- a/examples/experimental/inference_service/online_rollout.py +++ b/examples/experimental/inference_service/online_rollout.py @@ -4,6 +4,7 @@ import argparse import sys +from copy import deepcopy from dataclasses import asdict from pathlib import Path @@ -20,11 +21,8 @@ def main(args: list[str]) -> None: ext_args, remaining = parser.parse_known_args(args) from areal.api.cli_args import PPOConfig, load_expr_config - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.utils import logging from areal.utils.environ import is_single_controller @@ -54,26 +52,8 @@ def main(args: list[str]) -> None: is_external = ext_args.api_url is not None - ctrl_config = GatewayControllerConfig( - tokenizer_path=config.tokenizer_path, - model_path=config.actor.path, - consumer_batch_size=config.rollout.consumer_batch_size, - max_concurrent_rollouts=config.rollout.max_concurrent_rollouts, - max_head_offpolicyness=config.rollout.max_head_offpolicyness, - queue_size=config.rollout.queue_size, - enable_rollout_tracing=config.rollout.enable_rollout_tracing, - fileroot=config.rollout.fileroot, - experiment_name=config.rollout.experiment_name, - trial_name=config.rollout.trial_name, - dump_to_file=False, - backend=config.rollout.backend, - scheduling_spec=config.rollout.scheduling_spec, - setup_timeout=config.rollout.setup_timeout, - request_timeout=config.rollout.request_timeout, - admin_api_key=openai_cfg.admin_api_key, - turn_discount=openai_cfg.turn_discount, - export_style=openai_cfg.export_style, - ) + ctrl_config = deepcopy(config.rollout) + ctrl_config.dump_to_file = False if ext_args.model: ctrl_config.model = ext_args.model if is_external: @@ -91,7 +71,7 @@ def main(args: list[str]) -> None: else: raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") - ctrl = GatewayInferenceController(config=ctrl_config, scheduler=scheduler) + ctrl = RolloutControllerV2(config=ctrl_config, scheduler=scheduler) try: ctrl.initialize( role="rollout", diff --git a/examples/experimental/inference_service/online_rollout.yaml b/examples/experimental/inference_service/online_rollout.yaml index 307afdfa32..923a16da9e 100644 --- a/examples/experimental/inference_service/online_rollout.yaml +++ b/examples/experimental/inference_service/online_rollout.yaml @@ -19,6 +19,7 @@ scheduler: type: local rollout: + _version: v2 backend: "sglang:d1" experiment_name: ${experiment_name} trial_name: ${trial_name} @@ -33,6 +34,7 @@ rollout: dump_to_file: true request_timeout: 120.0 setup_timeout: 300.0 + admin_api_key: sk-test123456 openai: mode: online export_style: individual diff --git a/examples/experimental/inference_service/tau2_rollout.py b/examples/experimental/inference_service/tau2_rollout.py index d7c6828216..b78701b023 100644 --- a/examples/experimental/inference_service/tau2_rollout.py +++ b/examples/experimental/inference_service/tau2_rollout.py @@ -1,4 +1,4 @@ -"""Rollout-only script for Tau2 benchmark using GatewayInferenceController. +"""Rollout-only script for Tau2 benchmark using RolloutControllerV2. This example demonstrates how to run rollouts (data generation) without training, using the gateway HTTP stack to route inference requests. @@ -13,6 +13,7 @@ import sys import warnings +from copy import deepcopy from dataclasses import asdict, dataclass, field from typing import Any @@ -27,11 +28,8 @@ TrainDatasetConfig, load_expr_config, ) -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.utils import logging @@ -138,7 +136,7 @@ def get_tau2_dataset( @dataclass class Tau2GatewayRolloutConfig(BaseExperimentConfig): - """Configuration for Tau2 rollout-only with GatewayInferenceController.""" + """Configuration for Tau2 rollout-only with RolloutControllerV2.""" gconfig: GenerationHyperparameters = field( default_factory=GenerationHyperparameters @@ -160,7 +158,8 @@ def main(argv: list[str]) -> None: config, _ = load_expr_config(argv, Tau2GatewayRolloutConfig) econfig = config.econfig - rollout_cfg = config.rollout + rollout_cfg = deepcopy(config.rollout) + rollout_cfg.model = config.model_path # --- Dataset --- train_dataset = get_tau2_dataset( @@ -178,26 +177,6 @@ def main(argv: list[str]) -> None: num_workers=0, # in-process; tau2 dataset is lightweight ) - # --- Build GatewayControllerConfig from YAML rollout section --- - ctrl_config = GatewayControllerConfig( - tokenizer_path=config.tokenizer_path, - model_path=config.model_path, - consumer_batch_size=rollout_cfg.consumer_batch_size, - max_concurrent_rollouts=rollout_cfg.max_concurrent_rollouts, - max_head_offpolicyness=rollout_cfg.max_head_offpolicyness, - queue_size=rollout_cfg.queue_size, - enable_rollout_tracing=rollout_cfg.enable_rollout_tracing, - fileroot=rollout_cfg.fileroot, - experiment_name=rollout_cfg.experiment_name, - trial_name=rollout_cfg.trial_name, - dump_to_file=rollout_cfg.dump_to_file, - backend=rollout_cfg.backend, - scheduling_spec=rollout_cfg.scheduling_spec, - setup_timeout=rollout_cfg.setup_timeout, - request_timeout=rollout_cfg.request_timeout, - openai=rollout_cfg.openai, - ) - # --- Scheduler --- from areal.infra.scheduler.local import LocalScheduler from areal.infra.scheduler.slurm import SlurmScheduler @@ -219,7 +198,7 @@ def main(argv: list[str]) -> None: else: raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") - ctrl = GatewayInferenceController(config=ctrl_config, scheduler=scheduler) + ctrl = RolloutControllerV2(config=rollout_cfg, scheduler=scheduler) ctrl.initialize( role="rollout", server_args=server_args, diff --git a/examples/experimental/inference_service/tau2_rollout.yaml b/examples/experimental/inference_service/tau2_rollout.yaml index 13b19a5904..8fd69c6191 100644 --- a/examples/experimental/inference_service/tau2_rollout.yaml +++ b/examples/experimental/inference_service/tau2_rollout.yaml @@ -29,6 +29,7 @@ gconfig: temperature: 1.0 rollout: + _version: v2 backend: "sglang:d2" experiment_name: ${experiment_name} trial_name: ${trial_name} @@ -40,6 +41,7 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: false + admin_api_key: rollout-admin openai: mode: inline tool_call_parser: qwen25 diff --git a/tests/experimental/agent_service/test_controller.py b/tests/experimental/agent_service/test_controller.py index 376bed71c6..8a36a24540 100644 --- a/tests/experimental/agent_service/test_controller.py +++ b/tests/experimental/agent_service/test_controller.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -"""Unit tests for AgentServiceController. +"""Unit tests for AgentController. All Guard HTTP interactions are mocked — no real processes or servers. Tests cover: initialize, destroy, scale_up, scale_down, and error handling. @@ -10,15 +10,14 @@ import time from dataclasses import dataclass -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from areal.experimental.agent_service.controller.config import ( - AgentServiceControllerConfig, -) +from areal.api.cli_args import AgentConfig from areal.experimental.agent_service.controller.controller import ( - AgentServiceController, + AgentController, + _RuntimeSession, ) CTRL = "areal.experimental.agent_service.controller.controller" @@ -83,7 +82,7 @@ def _mock_health_response(active_sessions: int = 0) -> MagicMock: @pytest.fixture() def config(): - return AgentServiceControllerConfig( + return AgentConfig( agent_cls_path="my.Agent", admin_api_key="test-key", num_pairs=2, @@ -116,7 +115,7 @@ def mock_post(url, **kwargs): class TestConstruction: def test_construction(self, config): scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) assert ctrl.router_addr == "" assert ctrl.gateway_addr == "" assert ctrl.pairs == {} @@ -129,7 +128,7 @@ def test_initialize_forks_router_pairs_gateway(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090"), ("10.0.0.2", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() scheduler.create_workers.assert_called_once() @@ -148,7 +147,7 @@ def test_scale_up_adds_pairs(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert len(ctrl.pairs) == 0 @@ -178,7 +177,7 @@ def mock_post(url, **kwargs): mock_requests.RequestException = Exception scheduler = _make_scheduler(("g0", "8090"), ("g1", "8091")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() guards_called.clear() @@ -197,7 +196,7 @@ def test_scale_down_removes_newest_first(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert len(ctrl.pairs) == 3 @@ -214,7 +213,7 @@ def test_destroy_clears_everything(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert len(ctrl._forked_services) > 0 @@ -246,7 +245,7 @@ def mock_post(url, **kwargs): mock_requests.RequestException = Exception scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() ctrl.destroy() @@ -274,7 +273,7 @@ def mock_get(url, **kwargs): mock_requests.get = mock_get scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() health_call_count = 0 @@ -301,7 +300,7 @@ def counting_get(url, **kwargs): mock_requests.get = counting_get scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() pre_get_count = get_count @@ -318,7 +317,7 @@ def test_health_monitor_starts_and_stops(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert ctrl._health_thread is not None assert ctrl._health_thread.is_alive() @@ -333,8 +332,140 @@ def test_health_monitor_disabled_when_interval_zero(self, mock_requests, config) _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert ctrl._health_thread is None ctrl.destroy() + + +class TestRuntimeAPIs: + @pytest.mark.asyncio + async def test_start_session_grants_capacity_and_stores_session(self, config): + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentController(config=config, scheduler=scheduler) + ctrl._grant_capacity = AsyncMock() + ctrl._post_json = AsyncMock( + return_value={"session_id": "inf-sess-1", "api_key": "sess-key"} + ) + + session = await ctrl.start_session( + "task-1", + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_model="Qwen/Test", + ) + + assert session["session_id"].startswith("agent-sess-") + assert session["inference_session_id"] == "inf-sess-1" + assert session["api_key"] == "sess-key" + ctrl._grant_capacity.assert_awaited_once_with( + "http://inference", "rollout-admin" + ) + ctrl._post_json.assert_awaited_once_with( + "http://inference/rl/start_session", + payload={"task_id": "task-1"}, + headers={"Authorization": "Bearer rollout-admin"}, + ) + + stored = ctrl._resolve_session(session["session_id"]) + assert stored.inference_session_id == "inf-sess-1" + assert stored.inference_session_api_key == "sess-key" + assert stored.inference_model == "Qwen/Test" + + @pytest.mark.asyncio + async def test_step_posts_async_gateway_request(self, config): + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentController(config=config, scheduler=scheduler) + ctrl._gateway_addr = "http://agent-gateway" + ctrl._post_json = AsyncMock(return_value={"status": "completed"}) + ctrl._sessions["agent-sess-1"] = _RuntimeSession( + agent_session_id="agent-sess-1", + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_session_id="inf-sess-1", + inference_session_api_key="sess-key", + inference_model="Qwen/Test", + ) + + result = await ctrl.step( + "hello", + "agent-sess-1", + metadata={"extra": "value"}, + ) + + assert result == {"status": "completed"} + ctrl._post_json.assert_awaited_once_with( + "http://agent-gateway/v1/responses", + payload={ + "input": [{"type": "message", "content": "hello"}], + "model": "Qwen--Test", + "user": "agent-sess-1", + "metadata": { + "inference_base_url": "http://inference", + "inference_api_key": "sess-key", + "inference_model": "Qwen/Test", + "extra": "value", + }, + }, + headers={"Authorization": "Bearer test-key"}, + ) + + @pytest.mark.asyncio + async def test_set_reward_uses_session_api_key(self, config): + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentController(config=config, scheduler=scheduler) + ctrl._post_json = AsyncMock(return_value={"trajectory_id": 7}) + ctrl._sessions["agent-sess-1"] = _RuntimeSession( + agent_session_id="agent-sess-1", + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_session_id="inf-sess-1", + inference_session_api_key="sess-key", + ) + + result = await ctrl.set_reward(1.0, "agent-sess-1", interaction_id="resp-1") + + assert result == {"trajectory_id": 7} + ctrl._post_json.assert_awaited_once_with( + "http://inference/rl/set_reward", + payload={"interaction_id": "resp-1", "reward": 1.0}, + headers={"Authorization": "Bearer sess-key"}, + ) + + @pytest.mark.asyncio + @patch(f"{CTRL}.deserialize_interactions") + async def test_export_trajectory_deserializes_response( + self, mock_deserialize, config + ): + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentController(config=config, scheduler=scheduler) + ctrl._post_json = AsyncMock(return_value={"interactions": {"k": "v"}}) + ctrl._sessions["agent-sess-1"] = _RuntimeSession( + agent_session_id="agent-sess-1", + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_session_id="inf-sess-1", + inference_session_api_key="sess-key", + ) + mock_deserialize.return_value = {"interaction-1": MagicMock(reward=1.0)} + + result = await ctrl.export_trajectory( + "agent-sess-1", + trajectory_id=5, + discount=0.9, + style="individual", + ) + + assert "interaction-1" in result + ctrl._post_json.assert_awaited_once_with( + "http://inference/export_trajectories", + payload={ + "session_id": "inf-sess-1", + "trajectory_id": 5, + "discount": 0.9, + "style": "individual", + }, + headers={"Authorization": "Bearer rollout-admin"}, + ) + mock_deserialize.assert_called_once_with({"k": "v"}) diff --git a/tests/experimental/agent_service/test_guard.py b/tests/experimental/agent_service/test_guard.py index e5a82deac3..674808d618 100644 --- a/tests/experimental/agent_service/test_guard.py +++ b/tests/experimental/agent_service/test_guard.py @@ -2,7 +2,7 @@ Tests that the base guard routes are available on the agent guard app. The agent_blueprint has been removed in v2 — all orchestration logic -now lives in AgentServiceController. +now lives in AgentController. Test structure mirrors ``tests/experimental/inference_service/test_guard.py``. """ diff --git a/tests/experimental/agent_service/test_tau2_workflow.py b/tests/experimental/agent_service/test_tau2_workflow.py new file mode 100644 index 0000000000..a3daad4d6f --- /dev/null +++ b/tests/experimental/agent_service/test_tau2_workflow.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from examples.experimental.agent_service.tau2.workflow import Tau2AgentServiceWorkflow + + +@pytest.mark.asyncio +@patch("examples.experimental.agent_service.tau2.workflow.workflow_context") +@patch("examples.experimental.agent_service.tau2.workflow.stats_tracker") +async def test_arun_episode_returns_exported_interactions( + mock_stats_tracker, + mock_workflow_context, +): + controller = MagicMock() + controller.start_session = AsyncMock( + return_value={ + "session_id": "agent-sess-1", + "inference_session_id": "inf-sess-1", + "api_key": "sess-key", + } + ) + controller.set_reward = AsyncMock(return_value={"trajectory_id": 3}) + exported = {"last": SimpleNamespace(reward=1.0)} + controller.export_trajectory = AsyncMock(return_value=exported) + + workflow = Tau2AgentServiceWorkflow( + agent_controller=controller, + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_model="Qwen/Test", + econfig={"domain": "telecom", "solo_mode": True}, + timeout=10.0, + ) + workflow._run_dialog = AsyncMock(return_value=1.0) + mock_scope = object() + mock_workflow_context.get.return_value = SimpleNamespace( + task_id="task-from-context" + ) + mock_workflow_context.stat_scope.return_value = mock_scope + mock_stats = MagicMock() + mock_stats_tracker.get.return_value = mock_stats + + result = await workflow.arun_episode(engine=object(), data={"task_id": "task-1"}) + + assert result is exported + controller.start_session.assert_awaited_once_with( + task_id="task-1", + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_model="Qwen/Test", + ) + workflow._run_dialog.assert_awaited_once_with({"task_id": "task-1"}, "agent-sess-1") + controller.set_reward.assert_awaited_once_with(1.0, "agent-sess-1") + controller.export_trajectory.assert_awaited_once_with( + "agent-sess-1", + trajectory_id=3, + discount=1.0, + style="individual", + ) + mock_stats.scalar.assert_called_once_with(reward=1.0) diff --git a/tests/experimental/inference_service/test_controller.py b/tests/experimental/inference_service/test_controller.py index 5e61f768f1..25b0449c58 100644 --- a/tests/experimental/inference_service/test_controller.py +++ b/tests/experimental/inference_service/test_controller.py @@ -1,4 +1,4 @@ -"""Tests for GatewayInferenceController.""" +"""Tests for RolloutControllerV2.""" from __future__ import annotations @@ -8,34 +8,33 @@ import httpx import pytest -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) +from areal.api.cli_args import InferenceEngineConfig from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.inference_service.controller.workflow import ( InferenceServiceWorkflow, ) # ============================================================================= -# GatewayControllerConfig +# InferenceEngineConfig # ============================================================================= -class TestGatewayControllerConfig: +class TestInferenceEngineConfigForInferenceService: def test_defaults(self): - cfg = GatewayControllerConfig() - assert cfg.admin_api_key is None + cfg = InferenceEngineConfig(backend="sglang:d1") + assert cfg.admin_api_key == "areal-admin-key" assert cfg.model == "default" - assert cfg.consumer_batch_size == 16 + assert cfg.consumer_batch_size == 1 assert cfg.max_concurrent_rollouts is None assert cfg.max_head_offpolicyness == 0 assert cfg.enable_rollout_tracing is False assert cfg.set_reward_finish_timeout == 0.0 def test_custom_values(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="custom-key", consumer_batch_size=32, max_concurrent_rollouts=64, @@ -49,7 +48,8 @@ def test_custom_values(self): assert cfg.set_reward_finish_timeout == 3.0 def test_scheduling_fields(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", request_timeout=60.0, setup_timeout=600.0, ) @@ -57,30 +57,31 @@ def test_scheduling_fields(self): assert cfg.setup_timeout == 600.0 def test_dump_to_file_defaults_to_false(self): - cfg = GatewayControllerConfig() + cfg = InferenceEngineConfig(backend="sglang:d1") assert cfg.dump_to_file is False # ============================================================================= -# GatewayInferenceController — workflow resolution helpers +# RolloutControllerV2 — workflow resolution helpers # ============================================================================= class TestControllerWorkflowResolution: def test_resolve_workflow_with_instance(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) with pytest.raises(TypeError, match=r"callable run\(\) method"): controller._resolve_workflow(12345) def test_resolve_workflow_none_creates_online_inference_service_workflow(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://test:8080" resolved = controller._resolve_workflow( @@ -94,11 +95,12 @@ def test_resolve_workflow_none_creates_online_inference_service_workflow(self): assert resolved.timeout == 3.0 def test_resolve_workflow_agent_class_creates_offline_workflow(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://test:8080" class MockAgent: @@ -115,17 +117,17 @@ async def run(self, data, **kwargs): assert isinstance(resolved.agent, MockAgent) def test_resolve_should_accept_fn_none(self): - assert GatewayInferenceController._resolve_should_accept_fn(None) is None + assert RolloutControllerV2._resolve_should_accept_fn(None) is None def test_resolve_should_accept_fn_callable(self): fn = lambda x: True # noqa: E731 - assert GatewayInferenceController._resolve_should_accept_fn(fn) is fn + assert RolloutControllerV2._resolve_should_accept_fn(fn) is fn def test_resolve_workflow_with_agent_class(self): """Test _resolve_workflow wraps agent-like classes in InferenceServiceWorkflow.""" - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://test:8080" class MockAgent: @@ -141,8 +143,8 @@ async def run(self, data, **kwargs): assert hasattr(resolved, "arun_episode") def test_resolve_workflow_agent_class_without_gateway_raises(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) @@ -154,8 +156,8 @@ async def run(self, data, **kwargs): controller._resolve_workflow(MockAgent, workflow_kwargs={}) def test_resolve_workflow_rollout_workflow_instance_raises(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) controller._gateway_addr = "http://test:8080" @@ -172,8 +174,8 @@ def test_resolve_workflow_rollout_workflow_instance_raises(self): controller._resolve_workflow(workflow) def test_resolve_workflow_rollout_workflow_class_raises(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) controller._gateway_addr = "http://test:8080" @@ -188,11 +190,11 @@ def test_resolve_workflow_rollout_workflow_class_raises(self): # ============================================================================= -# GatewayInferenceController — API surface +# RolloutControllerV2 — API surface # ============================================================================= -class TestGatewayInferenceControllerAPISurface: +class TestRolloutControllerV2APISurface: def test_has_all_public_methods(self): methods = [ "initialize", @@ -216,7 +218,7 @@ def test_has_all_public_methods(self): "start_proxy_gateway", ] for m in methods: - assert hasattr(GatewayInferenceController, m), f"Missing method: {m}" + assert hasattr(RolloutControllerV2, m), f"Missing method: {m}" def test_has_properties(self): properties = [ @@ -228,35 +230,38 @@ def test_has_properties(self): "worker_ids", ] for p in properties: - assert hasattr(GatewayInferenceController, p), f"Missing property: {p}" + assert hasattr(RolloutControllerV2, p), f"Missing property: {p}" def test_not_subclass_of_rollout_controller(self): - """GatewayInferenceController must NOT be a subclass of RolloutController.""" + """RolloutControllerV2 must NOT be a subclass of RolloutController.""" # Verify it doesn't inherit from any class except object - bases = GatewayInferenceController.__bases__ + bases = RolloutControllerV2.__bases__ assert bases == (object,), f"Unexpected bases: {bases}" # ============================================================================= -# GatewayInferenceController — construction + state +# RolloutControllerV2 — construction + state # ============================================================================= -class TestGatewayInferenceControllerConstruction: +class TestRolloutControllerV2Construction: def test_admin_api_key_none_raises(self): - cfg = GatewayControllerConfig() + cfg = InferenceEngineConfig(backend="sglang:d1") + cfg.admin_api_key = "" with pytest.raises(ValueError, match="admin_api_key must be set"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_model_empty_raises(self): - cfg = GatewayControllerConfig(admin_api_key="test-key", model="") + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-key", model="" + ) with pytest.raises(ValueError, match="model must not be empty"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_constructor(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) assert controller.config is cfg assert controller.scheduler is scheduler @@ -268,63 +273,63 @@ def test_constructor(self): assert controller.worker_ids == {} def test_admin_api_key_defaults(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) assert controller.config.admin_api_key == "test-key" def test_version_management_without_services(self): """set_version / get_version work even without gateway services.""" - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # No gateway services started, but version management is local controller._version = 42 assert controller.get_version() == 42 def test_export_stats_returns_dict(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) stats = controller.export_stats() assert isinstance(stats, dict) def test_start_proxy_is_noop(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # Should not raise controller.start_proxy() controller.start_proxy_gateway() def test_proxy_gateway_addr(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # Before initialize, proxy_gateway_addr returns the empty _gateway_addr assert controller.proxy_gateway_addr == "" def test_callback_addr_formats_ipv6_hostport(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "2001:db8::10" controller._callback_port = 19000 assert controller.callback_addr == "[2001:db8::10]:19000" def test_workflow_executor_raises_before_init(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) with pytest.raises(RuntimeError, match="initialize"): _ = controller.workflow_executor def test_config_perf_tracer_is_noop(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # Should not raise controller.config_perf_tracer() controller.save_perf_tracer() @@ -343,7 +348,8 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy scheduler = MagicMock() scheduler.get_workers.return_value = [worker] - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", tokenizer_path="mock-tokenizer", request_timeout=15.0, set_reward_finish_timeout=7.5, @@ -357,7 +363,7 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy ), admin_api_key="test-admin-key", ) - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" controller._callback_port = 19000 @@ -385,15 +391,15 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy # ============================================================================= -# GatewayInferenceController — gateway HTTP helpers +# RolloutControllerV2 — gateway HTTP helpers # ============================================================================= -class TestGatewayInferenceControllerHTTP: +class TestRolloutControllerV2HTTP: def test_gateway_http_post_raises_on_failure(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://127.0.0.1:19999" with pytest.raises(RuntimeError, match="Failed to POST"): controller._gateway_http_post("/test", {"key": "value"}) @@ -404,11 +410,12 @@ def test_gateway_http_post_sends_auth(self, mock_post): mock_resp.status_code = 200 mock_post.return_value = mock_resp - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="my-secret-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://127.0.0.1:8080" controller._gateway_http_post("/test_endpoint", {"data": 1}) @@ -422,11 +429,12 @@ def test_gateway_http_post_sends_auth(self, mock_post): class TestOnlineCallbackFlow: @pytest.mark.asyncio async def test_online_callback_without_waiter_buffers_export_request(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() try: async with httpx.AsyncClient() as client: @@ -443,11 +451,12 @@ async def test_online_callback_without_waiter_buffers_export_request(self): @pytest.mark.asyncio async def test_online_callback_settles_waiter_once(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() waiter_task = asyncio.create_task( @@ -470,11 +479,12 @@ async def test_online_callback_settles_waiter_once(self): @pytest.mark.asyncio async def test_online_callback_invalid_payload_keeps_waiter_pending(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() waiter_task = asyncio.create_task( @@ -497,11 +507,12 @@ async def test_online_callback_invalid_payload_keeps_waiter_pending(self): @pytest.mark.asyncio async def test_cancelled_waiter_buffers_completed_online_result(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() waiter_task = asyncio.create_task( @@ -644,38 +655,39 @@ async def run(self, data, **kwargs): class TestMultiNodeConfig: def test_n_gpus_per_node_default_is_none(self): - cfg = GatewayControllerConfig() + cfg = InferenceEngineConfig(backend="sglang:d1") assert cfg.n_gpus_per_node is None def test_n_gpus_per_node_custom(self): - cfg = GatewayControllerConfig(n_gpus_per_node=4) + cfg = InferenceEngineConfig(backend="sglang:d1", n_gpus_per_node=4) assert cfg.n_gpus_per_node == 4 def test_n_gpus_per_node_zero_raises(self): - cfg = GatewayControllerConfig( - n_gpus_per_node=0, backend="sglang:d1t8", admin_api_key="test-key" + cfg = InferenceEngineConfig( + n_gpus_per_node=1, backend="sglang:d1t8", admin_api_key="test-key" ) + cfg.n_gpus_per_node = 0 with pytest.raises(ValueError, match="n_gpus_per_node must be >= 1"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_gpus_not_divisible_raises(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( n_gpus_per_node=3, backend="sglang:d1t8", admin_api_key="test-key" ) with pytest.raises(ValueError, match="must be divisible by n_gpus_per_node"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_single_node_backward_compat(self): - cfg = GatewayControllerConfig(backend="sglang:d2t4", admin_api_key="test-key") - controller = GatewayInferenceController(config=cfg, scheduler=MagicMock()) + cfg = InferenceEngineConfig(backend="sglang:d2t4", admin_api_key="test-key") + controller = RolloutControllerV2(config=cfg, scheduler=MagicMock()) assert controller._nnodes_per_instance == 1 def test_multi_node_valid_config(self): # tp=16, n_gpus_per_node=8 → nnodes_per_instance=2 - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( n_gpus_per_node=8, backend="sglang:d1t16", admin_api_key="test-key" ) - controller = GatewayInferenceController(config=cfg, scheduler=MagicMock()) + controller = RolloutControllerV2(config=cfg, scheduler=MagicMock()) assert controller._nnodes_per_instance == 2 @pytest.mark.asyncio @@ -698,14 +710,14 @@ async def test_async_initialize_multinode_worker_count(self): scheduler.get_workers.return_value = [worker0] # tp=8, n_gpus_per_node=4 → nnodes_per_instance=2 - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( tokenizer_path="mock-tokenizer", backend="sglang:d1t8", n_gpus_per_node=4, scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),), admin_api_key="test-key", ) - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" controller._callback_port = 19000 @@ -756,14 +768,14 @@ async def test_async_initialize_multinode_fork_path(self): scheduler.get_workers.return_value = [worker0, worker1] # tp=8, n_gpus_per_node=4 → nnodes_per_instance=2 - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( tokenizer_path="mock-tokenizer", backend="sglang:d1t8", n_gpus_per_node=4, scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),), admin_api_key="test-key", ) - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" controller._callback_port = 19000 diff --git a/tests/experimental/inference_service/test_controller_integration.py b/tests/experimental/inference_service/test_controller_integration.py index cdaa0b1868..380694ddb9 100644 --- a/tests/experimental/inference_service/test_controller_integration.py +++ b/tests/experimental/inference_service/test_controller_integration.py @@ -1,4 +1,4 @@ -"""Integration tests for GatewayInferenceController with real SGLang servers. +"""Integration tests for RolloutControllerV2 with real SGLang servers. Requires GPU and a model. Marked @pytest.mark.slow to exclude from default CI. Run manually: @@ -6,8 +6,8 @@ The test launches: 1. A real SGLang server (GPU subprocess) - 2. Module-scoped LocalScheduler / GatewayInferenceController fixtures - 3. A GatewayInferenceController that spins up Gateway, Router, and Data Proxy +2. Module-scoped LocalScheduler / RolloutControllerV2 fixtures +3. A RolloutControllerV2 that spins up Gateway, Router, and Data Proxy micro-services in background threads. """ @@ -121,14 +121,12 @@ def _make_gateway_controller_config( online_mode: bool = False, set_reward_finish_timeout: float = 0.0, ): - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec - return GatewayControllerConfig( + return InferenceEngineConfig( + backend="sglang:d1", tokenizer_path=model_path, - model_path=model_path, + model=model_path, set_reward_finish_timeout=set_reward_finish_timeout, scheduling_spec=( SchedulingSpec( @@ -189,17 +187,17 @@ def _export_trajectory_with_retry( @pytest.fixture(scope="module") def gateway_controller(sglang_server, model_path, tmp_path_factory): - """Create and initialize a GatewayInferenceController, yield it, then destroy.""" + """Create and initialize a RolloutControllerV2, yield it, then destroy.""" if not has_gpu(): pytest.skip("GPU required") from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) local_scheduler = _make_local_scheduler(tmp_path_factory, "gateway_controller") config = _make_gateway_controller_config(model_path) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout", @@ -219,14 +217,14 @@ def gateway_controller_online(sglang_server, model_path, tmp_path_factory): pytest.skip("GPU required") from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_online" ) config = _make_gateway_controller_config(model_path, online_mode=True) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout", @@ -246,7 +244,7 @@ def gateway_controller_with_reward_timeout(sglang_server, model_path, tmp_path_f pytest.skip("GPU required") from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) local_scheduler = _make_local_scheduler( @@ -256,7 +254,7 @@ def gateway_controller_with_reward_timeout(sglang_server, model_path, tmp_path_f model_path, set_reward_finish_timeout=3.0, ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-timeout", @@ -777,7 +775,7 @@ def test_offline_export_applies_discount_after_multiple_rewards_in_same_trajecto @pytest.fixture(scope="module") def gateway_controller_full_init(model_path, tmp_path_factory): - """Create a GatewayInferenceController that launches SGLang via the full init path. + """Create a RolloutControllerV2 that launches SGLang via the full init path. Unlike ``gateway_controller`` which passes pre-existing ``server_infos``, this fixture lets the controller create RPC workers, create @@ -786,17 +784,14 @@ def gateway_controller_full_init(model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) - config = GatewayControllerConfig( + config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, + model=model_path, backend="sglang:d1", scheduling_spec=( SchedulingSpec( @@ -810,6 +805,7 @@ def gateway_controller_full_init(model_path, tmp_path_factory): ) server_args = { + "model_path": model_path, "skip_tokenizer_init": True, "mem_fraction_static": 0.15, } @@ -817,7 +813,7 @@ def gateway_controller_full_init(model_path, tmp_path_factory): local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_full_init" ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-full", server_args=server_args, @@ -1075,17 +1071,14 @@ def gateway_controller_full_init_vllm(model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) - config = GatewayControllerConfig( + config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, + model=model_path, backend="vllm:d1", scheduling_spec=( SchedulingSpec( @@ -1099,13 +1092,14 @@ def gateway_controller_full_init_vllm(model_path, tmp_path_factory): ) server_args = { + "model": model_path, "gpu_memory_utilization": 0.15, } local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_full_init_vllm" ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-vllm", server_args=server_args, @@ -1280,17 +1274,14 @@ def gateway_controller_full_init_vlm_sglang(vlm_model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) - config = GatewayControllerConfig( + config = InferenceEngineConfig( tokenizer_path=vlm_model_path, - model_path=vlm_model_path, + model=vlm_model_path, backend="sglang:d1", scheduling_spec=( SchedulingSpec( @@ -1306,10 +1297,14 @@ def gateway_controller_full_init_vlm_sglang(vlm_model_path, tmp_path_factory): local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_full_init_vlm_sglang" ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-vlm-sglang", - server_args={"skip_tokenizer_init": True, "mem_fraction_static": 0.25}, + server_args={ + "model_path": vlm_model_path, + "skip_tokenizer_init": True, + "mem_fraction_static": 0.25, + }, ) try: @@ -1324,17 +1319,14 @@ def gateway_controller_full_init_vlm_vllm(vlm_model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) - config = GatewayControllerConfig( + config = InferenceEngineConfig( tokenizer_path=vlm_model_path, - model_path=vlm_model_path, + model=vlm_model_path, backend="vllm:d1", scheduling_spec=( SchedulingSpec( @@ -1350,7 +1342,7 @@ def gateway_controller_full_init_vlm_vllm(vlm_model_path, tmp_path_factory): local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_full_init_vlm_vllm" ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-vlm-vllm", server_args={"gpu_memory_utilization": 0.25}, diff --git a/tests/experimental/inference_service/test_controller_version.py b/tests/experimental/inference_service/test_controller_version.py index b2d90c3a5d..255bc6319a 100644 --- a/tests/experimental/inference_service/test_controller_version.py +++ b/tests/experimental/inference_service/test_controller_version.py @@ -1,4 +1,4 @@ -"""Unit tests for GatewayInferenceController version management. +"""Unit tests for RolloutControllerV2 version management. Tests set_version and get_version with mocked HTTP calls. """ @@ -8,12 +8,9 @@ import asyncio from unittest.mock import AsyncMock, MagicMock, patch -from areal.api.cli_args import SchedulingSpec -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) +from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) # ============================================================================= @@ -25,17 +22,18 @@ def _make_controller( gateway_addr: str = "", worker_ids: dict[str, str] | None = None, version: int = 0, -) -> GatewayInferenceController: +) -> RolloutControllerV2: """Create a controller with minimal config and manually injected state. Does NOT call initialize() — internal fields are set directly. """ - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-key", scheduling_spec=(SchedulingSpec(),), ) scheduler = MagicMock() - ctrl = GatewayInferenceController(config=cfg, scheduler=scheduler) + ctrl = RolloutControllerV2(config=cfg, scheduler=scheduler) ctrl._gateway_addr = gateway_addr ctrl._worker_ids = worker_ids or {} ctrl._version = version @@ -48,7 +46,7 @@ def _make_controller( class TestControllerSetVersion: - """Test GatewayInferenceController.set_version.""" + """Test RolloutControllerV2.set_version.""" def test_set_version_updates_local(self): ctrl = _make_controller() @@ -93,7 +91,7 @@ def test_set_version_broadcasts_to_all_workers(self): class TestControllerGetVersion: - """Test GatewayInferenceController.get_version.""" + """Test RolloutControllerV2.get_version.""" def test_get_version_returns_local(self): ctrl = _make_controller(version=0) diff --git a/tests/experimental/inference_service/test_examples.py b/tests/experimental/inference_service/test_examples.py index a91c552e56..763e4e5ee7 100644 --- a/tests/experimental/inference_service/test_examples.py +++ b/tests/experimental/inference_service/test_examples.py @@ -79,7 +79,8 @@ def wait_for_pattern(process, pattern, timeout, raise_on_exit=True): f"cluster.name_resolve.nfs_record_root={str(name_resolve_path)}", f"actor.path={model_path}", "scheduler.type=local", - f"rollout.openai.admin_api_key={admin_api_key}", + f"rollout.admin_api_key={admin_api_key}", + "rollout._version=v2", "stats_logger.wandb.mode=disabled", ] @@ -245,6 +246,7 @@ def test_tau2_rollout(tmp_path_factory): "cluster.n_gpus_per_node=2", f"cluster.fileroot={str(experiments_path)}", f"cluster.name_resolve.nfs_record_root={str(name_resolve_path)}", + "rollout.admin_api_key=test-admin-key", f"model_path={model_path}", "train_dataset.batch_size=2", "train_dataset.path=tau2/train", diff --git a/tests/experimental/weight_update/test_disk_integration.py b/tests/experimental/weight_update/test_disk_integration.py index fc4fde69f5..6824ba2894 100644 --- a/tests/experimental/weight_update/test_disk_integration.py +++ b/tests/experimental/weight_update/test_disk_integration.py @@ -360,15 +360,13 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): from areal.api import FinetuneSpec from areal.api.cli_args import ( + InferenceEngineConfig, OptimizerConfig, SchedulingSpec, TrainEngineConfig, ) - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.training_service.controller.controller import ( GatewayTrainController, @@ -385,9 +383,8 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): scheduler = _make_local_scheduler(tmp, "disk_e2e", gpu_devices=list(range(n_gpus))) - inf_config = GatewayControllerConfig( + inf_config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, backend=f"sglang:d{n_half}", scheduling_spec=( SchedulingSpec( @@ -400,7 +397,7 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): setup_timeout=300.0, admin_api_key="test-admin", ) - inf_ctrl = GatewayInferenceController(config=inf_config, scheduler=scheduler) + inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler) train_config = TrainEngineConfig( backend=f"fsdp:d{n_half}", @@ -430,7 +427,7 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): # -- 1. SGLang via inference controller ---------------------------- inf_ctrl.initialize( role="rollout", - server_args={"mem_fraction_static": 0.7}, + server_args={"model_path": model_path, "mem_fraction_static": 0.7}, ) inf_worker_urls = list(inf_ctrl._inf_addrs) diff --git a/tests/experimental/weight_update/test_nccl_integration.py b/tests/experimental/weight_update/test_nccl_integration.py index cd6f443a0c..7ceca68423 100644 --- a/tests/experimental/weight_update/test_nccl_integration.py +++ b/tests/experimental/weight_update/test_nccl_integration.py @@ -308,15 +308,13 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): from areal.api import FinetuneSpec from areal.api.cli_args import ( + InferenceEngineConfig, OptimizerConfig, SchedulingSpec, TrainEngineConfig, ) - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.training_service.controller.controller import ( GatewayTrainController, @@ -332,9 +330,8 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): scheduler = _make_local_scheduler(tmp, "e2e", gpu_devices=list(range(n_gpus))) - inf_config = GatewayControllerConfig( + inf_config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, backend=f"sglang:d{n_half}", scheduling_spec=( SchedulingSpec( @@ -347,7 +344,7 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): setup_timeout=300.0, admin_api_key="test-admin", ) - inf_ctrl = GatewayInferenceController(config=inf_config, scheduler=scheduler) + inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler) train_config = TrainEngineConfig( backend=f"fsdp:d{n_half}", @@ -377,7 +374,7 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): # -- 1. SGLang via inference controller ---------------------------- inf_ctrl.initialize( role="rollout", - server_args={"mem_fraction_static": 0.7}, + server_args={"model_path": model_path, "mem_fraction_static": 0.7}, ) inf_worker_urls = list(inf_ctrl._inf_addrs) @@ -510,12 +507,14 @@ def _run_megatron_awex_e2e( model_path: str | None = None, ): from areal.api import FinetuneSpec - from areal.api.cli_args import OptimizerConfig, SchedulingSpec, TrainEngineConfig - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, + from areal.api.cli_args import ( + InferenceEngineConfig, + OptimizerConfig, + SchedulingSpec, + TrainEngineConfig, ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.training_service.controller.controller import ( GatewayTrainController, @@ -530,9 +529,8 @@ def _run_megatron_awex_e2e( model_path = model_path or _get_test_model_path() scheduler = _make_local_scheduler(tmp, tag, gpu_devices=list(range(n_gpus))) - inf_config = GatewayControllerConfig( + inf_config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, backend=f"sglang:d{n_infer}", scheduling_spec=( SchedulingSpec( @@ -545,7 +543,7 @@ def _run_megatron_awex_e2e( setup_timeout=300.0, admin_api_key="test-admin", ) - inf_ctrl = GatewayInferenceController(config=inf_config, scheduler=scheduler) + inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler) train_config = TrainEngineConfig( backend=backend, @@ -571,7 +569,10 @@ def _run_megatron_awex_e2e( wu_ctrl: WeightUpdateController | None = None try: - inf_ctrl.initialize(role="rollout", server_args={"mem_fraction_static": 0.7}) + inf_ctrl.initialize( + role="rollout", + server_args={"model_path": model_path, "mem_fraction_static": 0.7}, + ) inf_worker_urls = list(inf_ctrl._inf_addrs) for url in inf_worker_urls: