From ed9558da24cff4722bde289a808b163e03209343 Mon Sep 17 00:00:00 2001 From: nuzant Date: Sat, 25 Apr 2026 17:45:56 +0000 Subject: [PATCH 01/12] refactor(service): unify experimental controller configs Consolidate the experimental agent and rollout controller configuration into areal.api.cli_args so the trainer, examples, and tests share one configuration surface. This also wires RolloutControllerV2 into the v2 rollout path and updates examples and integrations to use the new controller APIs. Key changes: - move agent and rollout controller configs into areal.api.cli_args and remove duplicated controller config modules - rename and rewire experimental controllers around AgentController and RolloutControllerV2 - update examples and experimental tests for rollout v2 config, initialization, and versioning Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- areal/api/cli_args.py | 122 +++++++++++ areal/experimental/agent_service/README.md | 5 +- areal/experimental/agent_service/__init__.py | 2 +- .../agent_service/controller/__init__.py | 9 +- .../agent_service/controller/config.py | 63 ------ .../agent_service/controller/controller.py | 16 +- .../inference_service/controller/__init__.py | 8 +- .../inference_service/controller/config.py | 66 ------ .../controller/controller.py | 100 +++++---- .../inference_service/controller/workflow.py | 4 +- .../data_service/controller/controller.py | 4 +- areal/trainer/rl_trainer.py | 12 +- areal/utils/logging.py | 4 +- examples/agent_service/README.md | 8 +- examples/agent_service/run_agent_service.py | 10 +- .../experimental/inference_service/README.md | 2 +- .../human_in_the_loop_demo.py | 3 +- .../inference_service/online_rollout.py | 30 +-- .../inference_service/online_rollout.yaml | 2 + .../inference_service/tau2_rollout.py | 35 +--- .../inference_service/tau2_rollout.yaml | 2 + .../agent_service/test_controller.py | 34 ++-- .../experimental/agent_service/test_guard.py | 2 +- .../inference_service/test_controller.py | 192 +++++++++--------- .../test_controller_integration.py | 92 ++++----- .../test_controller_version.py | 20 +- .../inference_service/test_examples.py | 4 +- .../weight_update/test_disk_integration.py | 13 +- .../weight_update/test_nccl_integration.py | 33 +-- 29 files changed, 436 insertions(+), 461 deletions(-) delete mode 100644 areal/experimental/agent_service/controller/config.py delete mode 100644 areal/experimental/inference_service/controller/config.py diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index edad30414d..2a7f76801f 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1963,6 +1963,60 @@ def __post_init__(self): raise ValueError("admin_api_key must not be empty or whitespace-only") +@dataclass +class AgentConfig: + """Configuration for the experimental agent service controller.""" + + agent_cls_path: str = field( + default="", + metadata={ + "help": "Fully-qualified import path for the AgentRunnable implementation." 
+ }, + ) + admin_api_key: str = field( + default="areal-agent-admin", + metadata={"help": "Shared admin API key for agent-service inter-service auth."}, + ) + num_pairs: int = field( + default=1, + metadata={"help": "Number of Worker+DataProxy pairs to launch on initialize."}, + ) + setup_timeout: float = field( + default=120.0, + metadata={"help": "Timeout in seconds waiting for each service to become healthy."}, + ) + health_poll_interval: float = field( + default=5.0, + metadata={"help": "Seconds between pair health polls; 0 disables health monitoring."}, + ) + drain_timeout: float = field( + default=30.0, + metadata={"help": "Seconds to wait for active sessions to drain before force-killing a pair."}, + ) + log_level: str = field( + default="info", + metadata={"help": "Log level for spawned agent-service micro-services."}, + ) + env: dict[str, str] = field( + default_factory=dict, + metadata={"help": "Extra environment variables passed to all forked child processes."}, + ) + + def __post_init__(self) -> None: + if not self.agent_cls_path: + raise ValueError("agent_cls_path must be a non-empty import path") + if self.num_pairs < 0: + raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}") + if self.setup_timeout <= 0: + raise ValueError( + f"setup_timeout must be positive, got {self.setup_timeout}" + ) + if self.drain_timeout < 0: + raise ValueError( + f"drain_timeout must be non-negative, got {self.drain_timeout}" + ) + + @dataclass class InferenceEngineConfig: """Configuration for inference servers, including offpolicyness control.""" @@ -2081,6 +2135,55 @@ class InferenceEngineConfig: }, ) + # v2 controller options + _version: str = field( + default="v1", + metadata={ + "help": "Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2.", + "choices": ["v1", "v2"], + }, + ) + model: str = field( + default="default", + metadata={"help": "Model name exposed through the inference-service gateway."}, + ) + routing_strategy: str = field( + default="round_robin", + metadata={"help": "Routing strategy for the inference-service router."}, + ) + poll_interval: float = field( + default=5.0, + metadata={"help": "Health-poll interval in seconds for the inference-service router."}, + ) + set_reward_finish_timeout: float = field( + default=0.0, + metadata={ + "help": "Timeout in seconds to wait for additional reward updates before finalizing a session." + }, + ) + log_level: str = field( + default="info", + metadata={"help": "Log level for inference-service micro-services."}, + ) + admin_api_key: str = field( + default="areal-admin-key", + metadata={ + "help": "Admin API key used by the inference-service gateway, router, and data proxies." 
+ }, + ) + api_url: str | None = field( + default=None, + metadata={"help": "External OpenAI-compatible base URL for inference-service external model mode."}, + ) + provider_api_key: str | None = field( + default=None, + metadata={"help": "API key for the external OpenAI-compatible provider."}, + ) + n_gpus_per_node: int | None = field( + default=None, + metadata={"help": "GPUs per physical node for multinode inference-service launch."}, + ) + def __post_init__(self): """Validate scheduling_spec length.""" if len(self.scheduling_spec) not in (1, 2): @@ -2088,6 +2191,25 @@ def __post_init__(self): f"scheduling_spec must contain 1 or 2 SchedulingSpec, " f"got {len(self.scheduling_spec)}" ) + if self._version not in ("v1", "v2"): + raise ValueError( + f"_version must be either 'v1' or 'v2', got '{self._version}'" + ) + if self.n_gpus_per_node is not None and self.n_gpus_per_node < 1: + raise ValueError( + f"n_gpus_per_node must be >= 1, got {self.n_gpus_per_node}" + ) + if not self.admin_api_key or not self.admin_api_key.strip(): + raise ValueError("admin_api_key must not be empty or whitespace-only") + if ( + self._version == "v2" + and self.openai is not None + and self.openai.admin_api_key != "areal-admin-key" + ): + logger.warning( + "rollout.openai.admin_api_key is ignored by rollout controller v2; " + "use rollout.admin_api_key instead." + ) @dataclass diff --git a/areal/experimental/agent_service/README.md b/areal/experimental/agent_service/README.md index f3dc0f839f..30f76bb8ec 100644 --- a/areal/experimental/agent_service/README.md +++ b/areal/experimental/agent_service/README.md @@ -159,9 +159,8 @@ areal/experimental/agent_service/ ├── protocol.py # Gateway protocol frame types ├── types.py # AgentRequest, AgentResponse, EventEmitter, AgentRunnable ├── controller/ -│ ├── __init__.py # AgentServiceController, AgentServiceControllerConfig -│ ├── config.py # AgentServiceControllerConfig dataclass -│ └── controller.py # AgentServiceController orchestrator +│ ├── __init__.py # AgentController export +│ └── controller.py # AgentController orchestrator ├── guard/ │ ├── __init__.py # Module docstring │ ├── __main__.py # python -m areal.experimental.agent_service.guard diff --git a/areal/experimental/agent_service/__init__.py b/areal/experimental/agent_service/__init__.py index 3858d5c133..2c964e550f 100644 --- a/areal/experimental/agent_service/__init__.py +++ b/areal/experimental/agent_service/__init__.py @@ -8,7 +8,7 @@ Submodules ---------- -- ``controller`` — :class:`AgentServiceController` orchestrator +- ``controller`` — :class:`AgentController` orchestrator - ``gateway`` — public HTTP/WebSocket entry point - ``router`` — session-affine routing - ``data_proxy`` — stateful session proxy diff --git a/areal/experimental/agent_service/controller/__init__.py b/areal/experimental/agent_service/controller/__init__.py index 3150205885..a677fc8b08 100644 --- a/areal/experimental/agent_service/controller/__init__.py +++ b/areal/experimental/agent_service/controller/__init__.py @@ -2,10 +2,11 @@ """Agent Service Controller — orchestrator for agent micro-services.""" -from .config import AgentServiceControllerConfig -from .controller import AgentServiceController +from areal.api.cli_args import AgentConfig + +from .controller import AgentController __all__ = [ - "AgentServiceController", - "AgentServiceControllerConfig", + "AgentController", + "AgentConfig", ] diff --git a/areal/experimental/agent_service/controller/config.py b/areal/experimental/agent_service/controller/config.py deleted file 
mode 100644 index c316d58227..0000000000 --- a/areal/experimental/agent_service/controller/config.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -"""Configuration for the AgentServiceController.""" - -from __future__ import annotations - -from dataclasses import dataclass, field - -from ..auth import DEFAULT_ADMIN_API_KEY - - -@dataclass -class AgentServiceControllerConfig: - """Unified configuration for AgentServiceController. - - Consolidates settings for the guard, router, gateway, worker, and - data proxy micro-services launched by the controller. - """ - - # -- Agent class ------------------------------------------------------- - agent_cls_path: str = "" - """Fully-qualified import path for the ``AgentRunnable`` implementation - (e.g. ``examples.agent_service.agent.Tau2Agent``).""" - - # -- Authentication ---------------------------------------------------- - admin_api_key: str = DEFAULT_ADMIN_API_KEY - """Shared admin API key for inter-service Bearer auth.""" - - # -- Scaling ----------------------------------------------------------- - num_pairs: int = 1 - """Number of Worker+DataProxy pairs to launch on initialize.""" - - # -- Timeouts ---------------------------------------------------------- - setup_timeout: float = 120.0 - """Timeout (seconds) waiting for each service to become healthy.""" - - health_poll_interval: float = 5.0 - """Seconds between health polls for crash detection (0 = disabled).""" - - drain_timeout: float = 30.0 - """Seconds to wait for active sessions to drain before force-killing a pair.""" - - # -- Log level --------------------------------------------------------- - log_level: str = "info" - """Log level for spawned micro-services.""" - - # -- Environment ------------------------------------------------------- - env: dict[str, str] = field(default_factory=dict) - """Extra environment variables to pass to all forked child processes.""" - - def __post_init__(self) -> None: - if not self.agent_cls_path: - raise ValueError("agent_cls_path must be a non-empty import path") - if self.num_pairs < 0: - raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}") - if self.setup_timeout <= 0: - raise ValueError( - f"setup_timeout must be positive, got {self.setup_timeout}" - ) - if self.drain_timeout < 0: - raise ValueError( - f"drain_timeout must be non-negative, got {self.drain_timeout}" - ) diff --git a/areal/experimental/agent_service/controller/controller.py b/areal/experimental/agent_service/controller/controller.py index 21b12851bb..e7465e52be 100644 --- a/areal/experimental/agent_service/controller/controller.py +++ b/areal/experimental/agent_service/controller/controller.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -"""AgentServiceController — orchestrates agent service micro-services via Guards. +"""AgentController — orchestrates agent service micro-services via Guards. Mirrors the architecture of -:class:`~areal.experimental.inference_service.controller.controller.GatewayInferenceController`: +:class:`~areal.experimental.inference_service.controller.controller.RolloutControllerV2`: Guard workers are created via the Scheduler, then the controller forks Router, Worker+DataProxy pairs, and Gateway onto them via HTTP API. @@ -12,7 +12,7 @@ from areal.infra.scheduler.local import LocalScheduler scheduler = LocalScheduler(...) - controller = AgentServiceController(config, scheduler) + controller = AgentController(config, scheduler) controller.initialize() # ... run traffic ... 
controller.scale_up(2) # add 2 Worker+DataProxy pairs @@ -32,16 +32,14 @@ import requests -from areal.experimental.agent_service.controller.config import ( - AgentServiceControllerConfig, -) +from areal.api.cli_args import AgentConfig from areal.utils import logging from areal.utils.network import format_hostport if TYPE_CHECKING: from areal.api.scheduler_api import Scheduler, Worker -logger = logging.getLogger("AgentServiceController") +logger = logging.getLogger("AgentController") _GUARD_ROLE = "agent-guard" _UNREGISTER_RETRIES = 3 @@ -60,7 +58,7 @@ class _WorkerPair: worker_addr: str -class AgentServiceController: +class AgentController: """Orchestrator for the Agent Service micro-service stack. Parameters @@ -73,7 +71,7 @@ class AgentServiceController: def __init__( self, - config: AgentServiceControllerConfig, + config: AgentConfig, scheduler: Scheduler, ) -> None: self.config = config diff --git a/areal/experimental/inference_service/controller/__init__.py b/areal/experimental/inference_service/controller/__init__.py index e122baa6db..9ac57d0d68 100644 --- a/areal/experimental/inference_service/controller/__init__.py +++ b/areal/experimental/inference_service/controller/__init__.py @@ -1,10 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) +from areal.api.cli_args import InferenceEngineConfig from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) -__all__ = ["GatewayControllerConfig", "GatewayInferenceController"] +__all__ = ["InferenceEngineConfig", "RolloutControllerV2"] diff --git a/areal/experimental/inference_service/controller/config.py b/areal/experimental/inference_service/controller/config.py deleted file mode 100644 index 4d45b1391b..0000000000 --- a/areal/experimental/inference_service/controller/config.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -"""Configuration for the GatewayInferenceController.""" - -from __future__ import annotations - -from dataclasses import dataclass, field - - -@dataclass -class GatewayControllerConfig: - """Unified configuration for GatewayInferenceController. - - Consolidates settings for the gateway, router, data proxy services, - and the WorkflowExecutor / staleness management. 
- """ - - # -- Model / tokenizer ------------------------------------------------- - tokenizer_path: str = "" - model_path: str = "" - model: str = "default" - - # -- Routing ----------------------------------------------------------- - routing_strategy: str = "round_robin" - poll_interval: float = 5.0 # router health-poll interval (seconds) - - # -- HTTP timeouts ----------------------------------------------------- - request_timeout: float = 120.0 # per-request timeout (seconds) - setup_timeout: float = 300.0 # timeout waiting for services to start - set_reward_finish_timeout: float = 0.0 - - # -- Log level for gateway micro-services ------------------------------ - log_level: str = "info" - - # -- WorkflowExecutor / staleness -------------------------------------- - consumer_batch_size: int = 16 - max_concurrent_rollouts: int | None = None - max_head_offpolicyness: int = 0 - queue_size: int | None = None - enable_rollout_tracing: bool = False - - # -- Trajectory dump --------------------------------------------------- - fileroot: str | None = None - experiment_name: str | None = None - trial_name: str | None = None - check_trajectory_format: bool = False - dump_to_file: bool = False - - # -- Scheduler / allocation (passed through from trainer) -------------- - backend: str = "sglang:d1" - scheduling_spec: tuple = field(default_factory=tuple) - pause_grace_period: float = 0.5 - n_gpus_per_node: int | None = None # GPUs per physical node; None = single-node - - # -- Admin / workflow -------------------------------------------------- - admin_api_key: str | None = None - turn_discount: float = 1.0 - export_style: str = "individual" - tool_call_parser: str = "qwen" - reasoning_parser: str = "qwen3" - engine_max_tokens: int | None = None - chat_template_type: str = "hf" - - # -- External model API ------------------------------------------------ - api_url: str | None = None - provider_api_key: str | None = None diff --git a/areal/experimental/inference_service/controller/controller.py b/areal/experimental/inference_service/controller/controller.py index 7cc546a554..57f511c61d 100644 --- a/areal/experimental/inference_service/controller/controller.py +++ b/areal/experimental/inference_service/controller/controller.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -"""GatewayInferenceController — parallel implementation to RolloutController. +"""RolloutControllerV2 — parallel implementation to RolloutController. Routes inference and pause/continue traffic through the gateway HTTP stack (Gateway → Router → Data Proxy → inference backend). @@ -11,6 +11,7 @@ from __future__ import annotations import asyncio +import copy import os import sys import threading @@ -28,14 +29,12 @@ if TYPE_CHECKING: from areal.api.scheduler_api import Scheduler, Worker +from areal.api.cli_args import InferenceEngineConfig from areal.api.io_struct import LocalInfServerInfo -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) from areal.utils import logging from areal.utils.network import format_hostport -logger = logging.getLogger("GatewayInferenceController") +logger = logging.getLogger("RolloutControllerV2") _MAX_COMPLETED_ONLINE_RESULTS = 1024 @@ -48,7 +47,7 @@ class _OnlineWaiter: class _DummyDataLoader: """Minimal dataloader that yields a single batch of empty dicts. - Used by :meth:`GatewayInferenceController.prepare_batch` when + Used by :meth:`RolloutControllerV2.prepare_batch` when ``dataloader`` is ``None`` (online-agent mode). 
""" @@ -59,7 +58,7 @@ def __iter__(self): yield [{} for _ in range(self.batch_size)] -class GatewayInferenceController: +class RolloutControllerV2: """Inference controller that routes everything through the gateway HTTP stack. This is a **parallel** implementation to ``RolloutController`` (NOT a @@ -79,15 +78,15 @@ class GatewayInferenceController: def __init__( self, - config: GatewayControllerConfig, + config: InferenceEngineConfig, scheduler: Scheduler, ) -> None: - if config.admin_api_key is None: + if config.admin_api_key is None or not config.admin_api_key.strip(): raise ValueError( - "GatewayControllerConfig.admin_api_key must be set (not None)" + "InferenceEngineConfig.admin_api_key must be set (not None or empty)" ) if not config.model: - raise ValueError("GatewayControllerConfig.model must not be empty") + raise ValueError("InferenceEngineConfig.model must not be empty") self.config = config self.scheduler = scheduler @@ -199,7 +198,6 @@ def initialize( self._register_data_proxies_in_router() # Create WorkflowExecutor directly (no intermediate engine) - from areal.api.cli_args import InferenceEngineConfig from areal.infra.remote_inf_engine import RemoteInfEngine from areal.infra.workflow_executor import WorkflowExecutor @@ -222,7 +220,7 @@ def initialize( max_staleness=self.config.max_head_offpolicyness, ) - logger.info("GatewayInferenceController initialized (role=%s)", role) + logger.info("RolloutControllerV2 initialized (role=%s)", role) if self.config.model: self.register_model( @@ -237,6 +235,18 @@ def initialize( self.config.model, ) + def offload(self) -> None: + """Offload hook placeholder for trainer compatibility.""" + logger.warning( + "RolloutControllerV2.offload is not implemented and will be skipped" + ) + + def onload(self, tags: list[str] | None = None) -> None: + """Onload hook placeholder for trainer compatibility.""" + logger.warning( + "RolloutControllerV2.onload is not implemented and will be skipped" + ) + async def _async_initialize( self, server_args: dict[str, Any] | None, @@ -266,6 +276,7 @@ async def _async_initialize( cfg = self.config admin_api_key = self.config.admin_api_key + openai_cfg = self._openai_config if self.external_mode: dp_size = 1 @@ -350,10 +361,9 @@ async def _async_initialize( if inf_backend == "sglang": from areal.api.cli_args import SGLangConfig - sglang_config = SGLangConfig( - model_path=cfg.model_path or cfg.tokenizer_path, - ) + sglang_config = SGLangConfig() if server_args: + sglang_config = copy.deepcopy(sglang_config) for k, v in server_args.items(): if hasattr(sglang_config, k): setattr(sglang_config, k, v) @@ -386,17 +396,19 @@ def _build_launch_cmd( elif inf_backend == "vllm": from areal.api.cli_args import vLLMConfig - vllm_config = vLLMConfig(model=cfg.model_path or cfg.tokenizer_path) - for k, v in (server_args or {}).items(): - if hasattr(vllm_config, k): - setattr(vllm_config, k, v) - else: - logger.warning( - "vLLMConfig has no attribute %r, ignoring " - "server_args entry (value=%r)", - k, - v, - ) + vllm_config = vLLMConfig() + if server_args: + vllm_config = copy.deepcopy(vllm_config) + for k, v in server_args.items(): + if hasattr(vllm_config, k): + setattr(vllm_config, k, v) + else: + logger.warning( + "vLLMConfig has no attribute %r, ignoring " + "server_args entry (value=%r)", + k, + v, + ) def _build_launch_cmd( host: str | None, @@ -583,16 +595,16 @@ def _build_launch_cmd( "--callback-server-addr", f"http://{self.callback_addr}", "--tool-call-parser", - cfg.tool_call_parser, + openai_cfg.tool_call_parser, 
"--reasoning-parser", - cfg.reasoning_parser, + openai_cfg.reasoning_parser, "--chat-template-type", - cfg.chat_template_type, + openai_cfg.chat_template_type, ] - if cfg.engine_max_tokens is not None: + if openai_cfg.engine_max_tokens is not None: data_proxy_base_cmd += [ "--engine-max-tokens", - str(cfg.engine_max_tokens), + str(openai_cfg.engine_max_tokens), ] for group_idx in range(dp_size): @@ -952,7 +964,7 @@ def get_version(self) -> int: def get_capacity(self) -> int: if self.staleness_manager is None: raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" + "RolloutControllerV2.initialize() must be called first" ) return self.staleness_manager.get_capacity() @@ -1045,7 +1057,7 @@ def rollout_batch( """ if not self._gateway_addr: raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" + "RolloutControllerV2.initialize() must be called first" ) if data is None: if batch_size is None: @@ -1116,7 +1128,7 @@ def prepare_batch( """ if not self._gateway_addr: raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" + "RolloutControllerV2.initialize() must be called first" ) if dataloader is None: if batch_size is None: @@ -1322,7 +1334,7 @@ def staleness_manager(self): def workflow_executor(self): if self._workflow_executor is None: raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" + "RolloutControllerV2.initialize() must be called first" ) return self._workflow_executor @@ -1353,8 +1365,8 @@ def _wrap_agent(self, agent: Any): "Gateway address is unavailable; initialize the controller first" ) - openai_cfg = self.config - admin_api_key = openai_cfg.admin_api_key + openai_cfg = self._openai_config + admin_api_key = self.config.admin_api_key turn_discount = openai_cfg.turn_discount export_style = openai_cfg.export_style @@ -1440,13 +1452,13 @@ def _resolve_workflow( # (c) Reject RolloutWorkflow classes and instances if isinstance(agent, type) and issubclass(agent, RolloutWorkflow): raise TypeError( - "GatewayInferenceController only accepts agent classes or instances with a " + "RolloutControllerV2 only accepts agent classes or instances with a " "run() method or None for online mode; direct RolloutWorkflow " "classes are not supported" ) if isinstance(agent, RolloutWorkflow): raise TypeError( - "GatewayInferenceController only accepts agent classes or instances with a " + "RolloutControllerV2 only accepts agent classes or instances with a " "run() method or None for online mode; direct RolloutWorkflow " "instances are not supported" ) @@ -1490,6 +1502,14 @@ def _resolve_should_accept_fn( return cast(Callable[[dict[str, Any]], bool], func) raise TypeError(f"Invalid should_accept_fn type: {type(should_accept_fn)}") + @property + def _openai_config(self): + from areal.api.cli_args import OpenAIProxyConfig + + return self.config.openai or OpenAIProxyConfig( + admin_api_key=self.config.admin_api_key + ) + # -- Internal HTTP helpers --------------------------------------------- def _fork_on_guard( diff --git a/areal/experimental/inference_service/controller/workflow.py b/areal/experimental/inference_service/controller/workflow.py index 95f0571770..3bbfeb8b97 100644 --- a/areal/experimental/inference_service/controller/workflow.py +++ b/areal/experimental/inference_service/controller/workflow.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from areal.api.engine_api import InferenceEngine from areal.experimental.inference_service.controller.controller import ( - 
GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.openai.types import InteractionWithTokenLogpReward @@ -29,7 +29,7 @@ class InferenceServiceWorkflow(RolloutWorkflow): def __init__( self, - controller: GatewayInferenceController, + controller: RolloutControllerV2, agent: Any | None = None, gateway_addr: str = "", admin_api_key: str = "areal-admin-key", diff --git a/areal/infra/data_service/controller/controller.py b/areal/infra/data_service/controller/controller.py index 5bfd2ae7ab..67188b8150 100644 --- a/areal/infra/data_service/controller/controller.py +++ b/areal/infra/data_service/controller/controller.py @@ -5,7 +5,7 @@ Manages the full lifecycle: create RPCGuard workers → fork DataWorkers, Router, Gateway → register datasets → serve batches → shutdown. -Follows the same patterns as ``GatewayInferenceController``. +Follows the same patterns as ``RolloutControllerV2``. """ from __future__ import annotations @@ -33,7 +33,7 @@ class DataController: """Controller for the distributed data loading service. - API follows ``TrainController`` / ``GatewayInferenceController`` patterns: + API follows ``TrainController`` / ``RolloutControllerV2`` patterns: ``__init__(config, scheduler)`` then ``initialize(role, ...)``. """ diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 978a6c18de..61fd4f994a 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -6,7 +6,7 @@ import os from collections.abc import Callable from copy import deepcopy -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch.distributed as dist from torchdata.stateful_dataloader import StatefulDataLoader @@ -35,6 +35,9 @@ vLLMConfig, ) from areal.engine import RemoteSGLangEngine, RemotevLLMEngine +from areal.experimental.inference_service.controller.controller import ( + RolloutControllerV2, +) from areal.infra import ( LocalScheduler, RayScheduler, @@ -1008,7 +1011,12 @@ def _init_rollout( return engine # Single-controller mode - no engine instantiation needed - controller = engine_cls.as_controller(config, self.scheduler) + if config._version == "v2": + controller = RolloutControllerV2( + config=config, scheduler=cast(Scheduler, self.scheduler) + ) + else: + controller = engine_cls.as_controller(config, self.scheduler) init_kwargs = dict( role="rollout", server_args=server_args, diff --git a/areal/utils/logging.py b/areal/utils/logging.py index 3d4f83bcd3..8b1130af8c 100644 --- a/areal/utils/logging.py +++ b/areal/utils/logging.py @@ -110,9 +110,9 @@ "AgentRouter": "light_purple", "AgentWorker": "light_purple", "AgentDataProxy": "light_purple", - "AgentServiceController": "light_purple", + "AgentController": "light_purple", # Inference service - white (orchestration) - "GatewayInferenceController": "white", + "RolloutControllerV2": "white", "InferenceDataProxy": "white", "InferenceInfBridge": "white", "InferenceRouter": "white", diff --git a/examples/agent_service/README.md b/examples/agent_service/README.md index 563064b4e5..1c7e8bc839 100644 --- a/examples/agent_service/README.md +++ b/examples/agent_service/README.md @@ -86,14 +86,14 @@ Turn 2: Client → Gateway → Router (same DataProxy) → DataProxy → Worker ```python from areal.experimental.agent_service.controller import ( - AgentServiceController, - AgentServiceControllerConfig, + AgentController, ) +from areal.api.cli_args import AgentConfig from areal.infra.scheduler.local import LocalScheduler scheduler = LocalScheduler(experiment_name="demo", 
trial_name="run0", gpu_devices=[]) -ctrl = AgentServiceController( - config=AgentServiceControllerConfig( +ctrl = AgentController( + config=AgentConfig( agent_cls_path="examples.agent_service.agent.ClaudeAgent", num_pairs=2, ), diff --git a/examples/agent_service/run_agent_service.py b/examples/agent_service/run_agent_service.py index e96f83f501..80850dc4ca 100644 --- a/examples/agent_service/run_agent_service.py +++ b/examples/agent_service/run_agent_service.py @@ -21,10 +21,8 @@ import httpx -from areal.experimental.agent_service.controller import ( - AgentServiceController, - AgentServiceControllerConfig, -) +from areal.api.cli_args import AgentConfig +from areal.experimental.agent_service.controller import AgentController async def _wait_healthy(url: str, timeout: float = 60.0) -> None: @@ -103,12 +101,12 @@ def main() -> None: gpu_devices=[], ) - ctrl_config = AgentServiceControllerConfig( + ctrl_config = AgentConfig( agent_cls_path="examples.agent_service.agent.ClaudeAgent", admin_api_key=args.admin_api_key, num_pairs=args.num_pairs, ) - ctrl = AgentServiceController(config=ctrl_config, scheduler=scheduler) + ctrl = AgentController(config=ctrl_config, scheduler=scheduler) try: print(f"Initializing with {args.num_pairs} pair(s) ...") diff --git a/examples/experimental/inference_service/README.md b/examples/experimental/inference_service/README.md index 1b7c0af21f..a039974b7b 100644 --- a/examples/experimental/inference_service/README.md +++ b/examples/experimental/inference_service/README.md @@ -1,7 +1,7 @@ # AReaL Inference Service Examples This directory contains two examples that use the AReaL Inference Service -(`GatewayInferenceController`) — an experimental rollout backend that exposes an +(`RolloutControllerV2`) — an experimental rollout backend that exposes an OpenAI-compatible proxy gateway so any external agent runtime can submit chat requests and receive RL training data. 
diff --git a/examples/experimental/inference_service/human_in_the_loop_demo.py b/examples/experimental/inference_service/human_in_the_loop_demo.py index 96f08c00a8..a6e833512d 100644 --- a/examples/experimental/inference_service/human_in_the_loop_demo.py +++ b/examples/experimental/inference_service/human_in_the_loop_demo.py @@ -283,7 +283,8 @@ def cleanup(signum=None, frame=None): str(config_yaml), f"actor.path={args.actor_path}", f"rollout.backend={args.inference_backend}:d1", - f"rollout.openai.admin_api_key={args.admin_key}", + f"rollout.admin_api_key={args.admin_key}", + "rollout._version=v2", f"rollout.request_timeout={args.request_timeout}", ] if args.api_url: diff --git a/examples/experimental/inference_service/online_rollout.py b/examples/experimental/inference_service/online_rollout.py index 106f8e86d0..dd0bfd1bdb 100644 --- a/examples/experimental/inference_service/online_rollout.py +++ b/examples/experimental/inference_service/online_rollout.py @@ -4,6 +4,7 @@ import argparse import sys +from copy import deepcopy from dataclasses import asdict from pathlib import Path @@ -20,11 +21,8 @@ def main(args: list[str]) -> None: ext_args, remaining = parser.parse_known_args(args) from areal.api.cli_args import PPOConfig, load_expr_config - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.utils import logging from areal.utils.environ import is_single_controller @@ -54,26 +52,8 @@ def main(args: list[str]) -> None: is_external = ext_args.api_url is not None - ctrl_config = GatewayControllerConfig( - tokenizer_path=config.tokenizer_path, - model_path=config.actor.path, - consumer_batch_size=config.rollout.consumer_batch_size, - max_concurrent_rollouts=config.rollout.max_concurrent_rollouts, - max_head_offpolicyness=config.rollout.max_head_offpolicyness, - queue_size=config.rollout.queue_size, - enable_rollout_tracing=config.rollout.enable_rollout_tracing, - fileroot=config.rollout.fileroot, - experiment_name=config.rollout.experiment_name, - trial_name=config.rollout.trial_name, - dump_to_file=False, - backend=config.rollout.backend, - scheduling_spec=config.rollout.scheduling_spec, - setup_timeout=config.rollout.setup_timeout, - request_timeout=config.rollout.request_timeout, - admin_api_key=openai_cfg.admin_api_key, - turn_discount=openai_cfg.turn_discount, - export_style=openai_cfg.export_style, - ) + ctrl_config = deepcopy(config.rollout) + ctrl_config.dump_to_file = False if ext_args.model: ctrl_config.model = ext_args.model if is_external: @@ -91,7 +71,7 @@ def main(args: list[str]) -> None: else: raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") - ctrl = GatewayInferenceController(config=ctrl_config, scheduler=scheduler) + ctrl = RolloutControllerV2(config=ctrl_config, scheduler=scheduler) try: ctrl.initialize( role="rollout", diff --git a/examples/experimental/inference_service/online_rollout.yaml b/examples/experimental/inference_service/online_rollout.yaml index 307afdfa32..923a16da9e 100644 --- a/examples/experimental/inference_service/online_rollout.yaml +++ b/examples/experimental/inference_service/online_rollout.yaml @@ -19,6 +19,7 @@ scheduler: type: local rollout: + _version: v2 backend: "sglang:d1" experiment_name: ${experiment_name} trial_name: ${trial_name} @@ -33,6 +34,7 @@ rollout: dump_to_file: true request_timeout: 120.0 setup_timeout: 300.0 + 
admin_api_key: sk-test123456 openai: mode: online export_style: individual diff --git a/examples/experimental/inference_service/tau2_rollout.py b/examples/experimental/inference_service/tau2_rollout.py index d7c6828216..b78701b023 100644 --- a/examples/experimental/inference_service/tau2_rollout.py +++ b/examples/experimental/inference_service/tau2_rollout.py @@ -1,4 +1,4 @@ -"""Rollout-only script for Tau2 benchmark using GatewayInferenceController. +"""Rollout-only script for Tau2 benchmark using RolloutControllerV2. This example demonstrates how to run rollouts (data generation) without training, using the gateway HTTP stack to route inference requests. @@ -13,6 +13,7 @@ import sys import warnings +from copy import deepcopy from dataclasses import asdict, dataclass, field from typing import Any @@ -27,11 +28,8 @@ TrainDatasetConfig, load_expr_config, ) -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.utils import logging @@ -138,7 +136,7 @@ def get_tau2_dataset( @dataclass class Tau2GatewayRolloutConfig(BaseExperimentConfig): - """Configuration for Tau2 rollout-only with GatewayInferenceController.""" + """Configuration for Tau2 rollout-only with RolloutControllerV2.""" gconfig: GenerationHyperparameters = field( default_factory=GenerationHyperparameters @@ -160,7 +158,8 @@ def main(argv: list[str]) -> None: config, _ = load_expr_config(argv, Tau2GatewayRolloutConfig) econfig = config.econfig - rollout_cfg = config.rollout + rollout_cfg = deepcopy(config.rollout) + rollout_cfg.model = config.model_path # --- Dataset --- train_dataset = get_tau2_dataset( @@ -178,26 +177,6 @@ def main(argv: list[str]) -> None: num_workers=0, # in-process; tau2 dataset is lightweight ) - # --- Build GatewayControllerConfig from YAML rollout section --- - ctrl_config = GatewayControllerConfig( - tokenizer_path=config.tokenizer_path, - model_path=config.model_path, - consumer_batch_size=rollout_cfg.consumer_batch_size, - max_concurrent_rollouts=rollout_cfg.max_concurrent_rollouts, - max_head_offpolicyness=rollout_cfg.max_head_offpolicyness, - queue_size=rollout_cfg.queue_size, - enable_rollout_tracing=rollout_cfg.enable_rollout_tracing, - fileroot=rollout_cfg.fileroot, - experiment_name=rollout_cfg.experiment_name, - trial_name=rollout_cfg.trial_name, - dump_to_file=rollout_cfg.dump_to_file, - backend=rollout_cfg.backend, - scheduling_spec=rollout_cfg.scheduling_spec, - setup_timeout=rollout_cfg.setup_timeout, - request_timeout=rollout_cfg.request_timeout, - openai=rollout_cfg.openai, - ) - # --- Scheduler --- from areal.infra.scheduler.local import LocalScheduler from areal.infra.scheduler.slurm import SlurmScheduler @@ -219,7 +198,7 @@ def main(argv: list[str]) -> None: else: raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") - ctrl = GatewayInferenceController(config=ctrl_config, scheduler=scheduler) + ctrl = RolloutControllerV2(config=rollout_cfg, scheduler=scheduler) ctrl.initialize( role="rollout", server_args=server_args, diff --git a/examples/experimental/inference_service/tau2_rollout.yaml b/examples/experimental/inference_service/tau2_rollout.yaml index 13b19a5904..8fd69c6191 100644 --- a/examples/experimental/inference_service/tau2_rollout.yaml +++ b/examples/experimental/inference_service/tau2_rollout.yaml @@ -29,6 +29,7 @@ gconfig: temperature: 1.0 rollout: + _version: v2 backend: 
"sglang:d2" experiment_name: ${experiment_name} trial_name: ${trial_name} @@ -40,6 +41,7 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: false + admin_api_key: rollout-admin openai: mode: inline tool_call_parser: qwen25 diff --git a/tests/experimental/agent_service/test_controller.py b/tests/experimental/agent_service/test_controller.py index 376bed71c6..bacbd1bbe9 100644 --- a/tests/experimental/agent_service/test_controller.py +++ b/tests/experimental/agent_service/test_controller.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -"""Unit tests for AgentServiceController. +"""Unit tests for AgentController. All Guard HTTP interactions are mocked — no real processes or servers. Tests cover: initialize, destroy, scale_up, scale_down, and error handling. @@ -14,12 +14,8 @@ import pytest -from areal.experimental.agent_service.controller.config import ( - AgentServiceControllerConfig, -) -from areal.experimental.agent_service.controller.controller import ( - AgentServiceController, -) +from areal.api.cli_args import AgentConfig +from areal.experimental.agent_service.controller.controller import AgentController CTRL = "areal.experimental.agent_service.controller.controller" @@ -83,7 +79,7 @@ def _mock_health_response(active_sessions: int = 0) -> MagicMock: @pytest.fixture() def config(): - return AgentServiceControllerConfig( + return AgentConfig( agent_cls_path="my.Agent", admin_api_key="test-key", num_pairs=2, @@ -116,7 +112,7 @@ def mock_post(url, **kwargs): class TestConstruction: def test_construction(self, config): scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) assert ctrl.router_addr == "" assert ctrl.gateway_addr == "" assert ctrl.pairs == {} @@ -129,7 +125,7 @@ def test_initialize_forks_router_pairs_gateway(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090"), ("10.0.0.2", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() scheduler.create_workers.assert_called_once() @@ -148,7 +144,7 @@ def test_scale_up_adds_pairs(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert len(ctrl.pairs) == 0 @@ -178,7 +174,7 @@ def mock_post(url, **kwargs): mock_requests.RequestException = Exception scheduler = _make_scheduler(("g0", "8090"), ("g1", "8091")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() guards_called.clear() @@ -197,7 +193,7 @@ def test_scale_down_removes_newest_first(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert len(ctrl.pairs) == 3 @@ -214,7 +210,7 @@ def test_destroy_clears_everything(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, 
scheduler=scheduler) ctrl.initialize() assert len(ctrl._forked_services) > 0 @@ -246,7 +242,7 @@ def mock_post(url, **kwargs): mock_requests.RequestException = Exception scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() ctrl.destroy() @@ -274,7 +270,7 @@ def mock_get(url, **kwargs): mock_requests.get = mock_get scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() health_call_count = 0 @@ -301,7 +297,7 @@ def counting_get(url, **kwargs): mock_requests.get = counting_get scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() pre_get_count = get_count @@ -318,7 +314,7 @@ def test_health_monitor_starts_and_stops(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert ctrl._health_thread is not None assert ctrl._health_thread.is_alive() @@ -333,7 +329,7 @@ def test_health_monitor_disabled_when_interval_zero(self, mock_requests, config) _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert ctrl._health_thread is None diff --git a/tests/experimental/agent_service/test_guard.py b/tests/experimental/agent_service/test_guard.py index e5a82deac3..674808d618 100644 --- a/tests/experimental/agent_service/test_guard.py +++ b/tests/experimental/agent_service/test_guard.py @@ -2,7 +2,7 @@ Tests that the base guard routes are available on the agent guard app. The agent_blueprint has been removed in v2 — all orchestration logic -now lives in AgentServiceController. +now lives in AgentController. Test structure mirrors ``tests/experimental/inference_service/test_guard.py``. 
""" diff --git a/tests/experimental/inference_service/test_controller.py b/tests/experimental/inference_service/test_controller.py index 5e61f768f1..95a929a7e9 100644 --- a/tests/experimental/inference_service/test_controller.py +++ b/tests/experimental/inference_service/test_controller.py @@ -1,4 +1,4 @@ -"""Tests for GatewayInferenceController.""" +"""Tests for RolloutControllerV2.""" from __future__ import annotations @@ -8,34 +8,33 @@ import httpx import pytest -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) +from areal.api.cli_args import InferenceEngineConfig, OpenAIProxyConfig from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.inference_service.controller.workflow import ( InferenceServiceWorkflow, ) # ============================================================================= -# GatewayControllerConfig +# InferenceEngineConfig # ============================================================================= -class TestGatewayControllerConfig: +class TestInferenceEngineConfigForInferenceService: def test_defaults(self): - cfg = GatewayControllerConfig() - assert cfg.admin_api_key is None + cfg = InferenceEngineConfig(backend="sglang:d1") + assert cfg.admin_api_key == "areal-admin-key" assert cfg.model == "default" - assert cfg.consumer_batch_size == 16 + assert cfg.consumer_batch_size == 1 assert cfg.max_concurrent_rollouts is None assert cfg.max_head_offpolicyness == 0 assert cfg.enable_rollout_tracing is False assert cfg.set_reward_finish_timeout == 0.0 def test_custom_values(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="custom-key", consumer_batch_size=32, max_concurrent_rollouts=64, @@ -49,7 +48,8 @@ def test_custom_values(self): assert cfg.set_reward_finish_timeout == 3.0 def test_scheduling_fields(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", request_timeout=60.0, setup_timeout=600.0, ) @@ -57,30 +57,31 @@ def test_scheduling_fields(self): assert cfg.setup_timeout == 600.0 def test_dump_to_file_defaults_to_false(self): - cfg = GatewayControllerConfig() + cfg = InferenceEngineConfig(backend="sglang:d1") assert cfg.dump_to_file is False # ============================================================================= -# GatewayInferenceController — workflow resolution helpers +# RolloutControllerV2 — workflow resolution helpers # ============================================================================= class TestControllerWorkflowResolution: def test_resolve_workflow_with_instance(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) with pytest.raises(TypeError, match=r"callable run\(\) method"): controller._resolve_workflow(12345) def test_resolve_workflow_none_creates_online_inference_service_workflow(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://test:8080" resolved = controller._resolve_workflow( @@ -94,11 +95,12 @@ def 
test_resolve_workflow_none_creates_online_inference_service_workflow(self): assert resolved.timeout == 3.0 def test_resolve_workflow_agent_class_creates_offline_workflow(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://test:8080" class MockAgent: @@ -115,17 +117,17 @@ async def run(self, data, **kwargs): assert isinstance(resolved.agent, MockAgent) def test_resolve_should_accept_fn_none(self): - assert GatewayInferenceController._resolve_should_accept_fn(None) is None + assert RolloutControllerV2._resolve_should_accept_fn(None) is None def test_resolve_should_accept_fn_callable(self): fn = lambda x: True # noqa: E731 - assert GatewayInferenceController._resolve_should_accept_fn(fn) is fn + assert RolloutControllerV2._resolve_should_accept_fn(fn) is fn def test_resolve_workflow_with_agent_class(self): """Test _resolve_workflow wraps agent-like classes in InferenceServiceWorkflow.""" - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://test:8080" class MockAgent: @@ -141,8 +143,8 @@ async def run(self, data, **kwargs): assert hasattr(resolved, "arun_episode") def test_resolve_workflow_agent_class_without_gateway_raises(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) @@ -154,8 +156,8 @@ async def run(self, data, **kwargs): controller._resolve_workflow(MockAgent, workflow_kwargs={}) def test_resolve_workflow_rollout_workflow_instance_raises(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) controller._gateway_addr = "http://test:8080" @@ -172,8 +174,8 @@ def test_resolve_workflow_rollout_workflow_instance_raises(self): controller._resolve_workflow(workflow) def test_resolve_workflow_rollout_workflow_class_raises(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) controller._gateway_addr = "http://test:8080" @@ -188,11 +190,11 @@ def test_resolve_workflow_rollout_workflow_class_raises(self): # ============================================================================= -# GatewayInferenceController — API surface +# RolloutControllerV2 — API surface # ============================================================================= -class TestGatewayInferenceControllerAPISurface: +class TestRolloutControllerV2APISurface: def test_has_all_public_methods(self): methods = [ "initialize", @@ -216,7 +218,7 @@ def test_has_all_public_methods(self): "start_proxy_gateway", ] for m in methods: - assert hasattr(GatewayInferenceController, m), f"Missing method: {m}" 
+ assert hasattr(RolloutControllerV2, m), f"Missing method: {m}" def test_has_properties(self): properties = [ @@ -228,35 +230,36 @@ def test_has_properties(self): "worker_ids", ] for p in properties: - assert hasattr(GatewayInferenceController, p), f"Missing property: {p}" + assert hasattr(RolloutControllerV2, p), f"Missing property: {p}" def test_not_subclass_of_rollout_controller(self): - """GatewayInferenceController must NOT be a subclass of RolloutController.""" + """RolloutControllerV2 must NOT be a subclass of RolloutController.""" # Verify it doesn't inherit from any class except object - bases = GatewayInferenceController.__bases__ + bases = RolloutControllerV2.__bases__ assert bases == (object,), f"Unexpected bases: {bases}" # ============================================================================= -# GatewayInferenceController — construction + state +# RolloutControllerV2 — construction + state # ============================================================================= -class TestGatewayInferenceControllerConstruction: +class TestRolloutControllerV2Construction: def test_admin_api_key_none_raises(self): - cfg = GatewayControllerConfig() + cfg = InferenceEngineConfig(backend="sglang:d1") + cfg.admin_api_key = "" with pytest.raises(ValueError, match="admin_api_key must be set"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_model_empty_raises(self): - cfg = GatewayControllerConfig(admin_api_key="test-key", model="") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key", model="") with pytest.raises(ValueError, match="model must not be empty"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_constructor(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) assert controller.config is cfg assert controller.scheduler is scheduler @@ -268,63 +271,63 @@ def test_constructor(self): assert controller.worker_ids == {} def test_admin_api_key_defaults(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) assert controller.config.admin_api_key == "test-key" def test_version_management_without_services(self): """set_version / get_version work even without gateway services.""" - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # No gateway services started, but version management is local controller._version = 42 assert controller.get_version() == 42 def test_export_stats_returns_dict(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = 
RolloutControllerV2(config=cfg, scheduler=scheduler) stats = controller.export_stats() assert isinstance(stats, dict) def test_start_proxy_is_noop(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # Should not raise controller.start_proxy() controller.start_proxy_gateway() def test_proxy_gateway_addr(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # Before initialize, proxy_gateway_addr returns the empty _gateway_addr assert controller.proxy_gateway_addr == "" def test_callback_addr_formats_ipv6_hostport(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "2001:db8::10" controller._callback_port = 19000 assert controller.callback_addr == "[2001:db8::10]:19000" def test_workflow_executor_raises_before_init(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) with pytest.raises(RuntimeError, match="initialize"): _ = controller.workflow_executor def test_config_perf_tracer_is_noop(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # Should not raise controller.config_perf_tracer() controller.save_perf_tracer() @@ -343,7 +346,8 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy scheduler = MagicMock() scheduler.get_workers.return_value = [worker] - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", tokenizer_path="mock-tokenizer", request_timeout=15.0, set_reward_finish_timeout=7.5, @@ -357,7 +361,7 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy ), admin_api_key="test-admin-key", ) - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" controller._callback_port = 19000 @@ -385,15 +389,15 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy # ============================================================================= -# GatewayInferenceController — gateway HTTP helpers +# RolloutControllerV2 — gateway HTTP helpers # ============================================================================= -class TestGatewayInferenceControllerHTTP: +class TestRolloutControllerV2HTTP: def test_gateway_http_post_raises_on_failure(self): - cfg = 
GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://127.0.0.1:19999" with pytest.raises(RuntimeError, match="Failed to POST"): controller._gateway_http_post("/test", {"key": "value"}) @@ -404,11 +408,12 @@ def test_gateway_http_post_sends_auth(self, mock_post): mock_resp.status_code = 200 mock_post.return_value = mock_resp - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="my-secret-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://127.0.0.1:8080" controller._gateway_http_post("/test_endpoint", {"data": 1}) @@ -422,11 +427,12 @@ def test_gateway_http_post_sends_auth(self, mock_post): class TestOnlineCallbackFlow: @pytest.mark.asyncio async def test_online_callback_without_waiter_buffers_export_request(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() try: async with httpx.AsyncClient() as client: @@ -443,11 +449,12 @@ async def test_online_callback_without_waiter_buffers_export_request(self): @pytest.mark.asyncio async def test_online_callback_settles_waiter_once(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() waiter_task = asyncio.create_task( @@ -470,11 +477,12 @@ async def test_online_callback_settles_waiter_once(self): @pytest.mark.asyncio async def test_online_callback_invalid_payload_keeps_waiter_pending(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() waiter_task = asyncio.create_task( @@ -497,11 +505,12 @@ async def test_online_callback_invalid_payload_keeps_waiter_pending(self): @pytest.mark.asyncio async def test_cancelled_waiter_buffers_completed_online_result(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() waiter_task = asyncio.create_task( @@ -644,38 +653,39 @@ async def run(self, data, **kwargs): class TestMultiNodeConfig: def test_n_gpus_per_node_default_is_none(self): - cfg = GatewayControllerConfig() + cfg = InferenceEngineConfig(backend="sglang:d1") assert cfg.n_gpus_per_node is None def test_n_gpus_per_node_custom(self): - cfg = 
GatewayControllerConfig(n_gpus_per_node=4) + cfg = InferenceEngineConfig(backend="sglang:d1", n_gpus_per_node=4) assert cfg.n_gpus_per_node == 4 def test_n_gpus_per_node_zero_raises(self): - cfg = GatewayControllerConfig( - n_gpus_per_node=0, backend="sglang:d1t8", admin_api_key="test-key" + cfg = InferenceEngineConfig( + n_gpus_per_node=1, backend="sglang:d1t8", admin_api_key="test-key" ) + cfg.n_gpus_per_node = 0 with pytest.raises(ValueError, match="n_gpus_per_node must be >= 1"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_gpus_not_divisible_raises(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( n_gpus_per_node=3, backend="sglang:d1t8", admin_api_key="test-key" ) with pytest.raises(ValueError, match="must be divisible by n_gpus_per_node"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_single_node_backward_compat(self): - cfg = GatewayControllerConfig(backend="sglang:d2t4", admin_api_key="test-key") - controller = GatewayInferenceController(config=cfg, scheduler=MagicMock()) + cfg = InferenceEngineConfig(backend="sglang:d2t4", admin_api_key="test-key") + controller = RolloutControllerV2(config=cfg, scheduler=MagicMock()) assert controller._nnodes_per_instance == 1 def test_multi_node_valid_config(self): # tp=16, n_gpus_per_node=8 → nnodes_per_instance=2 - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( n_gpus_per_node=8, backend="sglang:d1t16", admin_api_key="test-key" ) - controller = GatewayInferenceController(config=cfg, scheduler=MagicMock()) + controller = RolloutControllerV2(config=cfg, scheduler=MagicMock()) assert controller._nnodes_per_instance == 2 @pytest.mark.asyncio @@ -698,14 +708,14 @@ async def test_async_initialize_multinode_worker_count(self): scheduler.get_workers.return_value = [worker0] # tp=8, n_gpus_per_node=4 → nnodes_per_instance=2 - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( tokenizer_path="mock-tokenizer", backend="sglang:d1t8", n_gpus_per_node=4, scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),), admin_api_key="test-key", ) - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" controller._callback_port = 19000 @@ -756,14 +766,14 @@ async def test_async_initialize_multinode_fork_path(self): scheduler.get_workers.return_value = [worker0, worker1] # tp=8, n_gpus_per_node=4 → nnodes_per_instance=2 - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( tokenizer_path="mock-tokenizer", backend="sglang:d1t8", n_gpus_per_node=4, scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),), admin_api_key="test-key", ) - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" controller._callback_port = 19000 diff --git a/tests/experimental/inference_service/test_controller_integration.py b/tests/experimental/inference_service/test_controller_integration.py index cdaa0b1868..380694ddb9 100644 --- a/tests/experimental/inference_service/test_controller_integration.py +++ b/tests/experimental/inference_service/test_controller_integration.py @@ -1,4 +1,4 @@ -"""Integration tests for GatewayInferenceController with real SGLang servers. 
+"""Integration tests for RolloutControllerV2 with real SGLang servers. Requires GPU and a model. Marked @pytest.mark.slow to exclude from default CI. Run manually: @@ -6,8 +6,8 @@ The test launches: 1. A real SGLang server (GPU subprocess) - 2. Module-scoped LocalScheduler / GatewayInferenceController fixtures - 3. A GatewayInferenceController that spins up Gateway, Router, and Data Proxy +2. Module-scoped LocalScheduler / RolloutControllerV2 fixtures +3. A RolloutControllerV2 that spins up Gateway, Router, and Data Proxy micro-services in background threads. """ @@ -121,14 +121,12 @@ def _make_gateway_controller_config( online_mode: bool = False, set_reward_finish_timeout: float = 0.0, ): - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec - return GatewayControllerConfig( + return InferenceEngineConfig( + backend="sglang:d1", tokenizer_path=model_path, - model_path=model_path, + model=model_path, set_reward_finish_timeout=set_reward_finish_timeout, scheduling_spec=( SchedulingSpec( @@ -189,17 +187,17 @@ def _export_trajectory_with_retry( @pytest.fixture(scope="module") def gateway_controller(sglang_server, model_path, tmp_path_factory): - """Create and initialize a GatewayInferenceController, yield it, then destroy.""" + """Create and initialize a RolloutControllerV2, yield it, then destroy.""" if not has_gpu(): pytest.skip("GPU required") from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) local_scheduler = _make_local_scheduler(tmp_path_factory, "gateway_controller") config = _make_gateway_controller_config(model_path) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout", @@ -219,14 +217,14 @@ def gateway_controller_online(sglang_server, model_path, tmp_path_factory): pytest.skip("GPU required") from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_online" ) config = _make_gateway_controller_config(model_path, online_mode=True) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout", @@ -246,7 +244,7 @@ def gateway_controller_with_reward_timeout(sglang_server, model_path, tmp_path_f pytest.skip("GPU required") from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) local_scheduler = _make_local_scheduler( @@ -256,7 +254,7 @@ def gateway_controller_with_reward_timeout(sglang_server, model_path, tmp_path_f model_path, set_reward_finish_timeout=3.0, ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-timeout", @@ -777,7 +775,7 @@ def test_offline_export_applies_discount_after_multiple_rewards_in_same_trajecto @pytest.fixture(scope="module") def gateway_controller_full_init(model_path, tmp_path_factory): - """Create a GatewayInferenceController that launches SGLang via the full init path. 
+ """Create a RolloutControllerV2 that launches SGLang via the full init path. Unlike ``gateway_controller`` which passes pre-existing ``server_infos``, this fixture lets the controller create RPC workers, create @@ -786,17 +784,14 @@ def gateway_controller_full_init(model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) - config = GatewayControllerConfig( + config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, + model=model_path, backend="sglang:d1", scheduling_spec=( SchedulingSpec( @@ -810,6 +805,7 @@ def gateway_controller_full_init(model_path, tmp_path_factory): ) server_args = { + "model_path": model_path, "skip_tokenizer_init": True, "mem_fraction_static": 0.15, } @@ -817,7 +813,7 @@ def gateway_controller_full_init(model_path, tmp_path_factory): local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_full_init" ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-full", server_args=server_args, @@ -1075,17 +1071,14 @@ def gateway_controller_full_init_vllm(model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) - config = GatewayControllerConfig( + config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, + model=model_path, backend="vllm:d1", scheduling_spec=( SchedulingSpec( @@ -1099,13 +1092,14 @@ def gateway_controller_full_init_vllm(model_path, tmp_path_factory): ) server_args = { + "model": model_path, "gpu_memory_utilization": 0.15, } local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_full_init_vllm" ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-vllm", server_args=server_args, @@ -1280,17 +1274,14 @@ def gateway_controller_full_init_vlm_sglang(vlm_model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) - config = GatewayControllerConfig( + config = InferenceEngineConfig( tokenizer_path=vlm_model_path, - model_path=vlm_model_path, + model=vlm_model_path, backend="sglang:d1", scheduling_spec=( SchedulingSpec( @@ -1306,10 +1297,14 @@ def gateway_controller_full_init_vlm_sglang(vlm_model_path, tmp_path_factory): local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_full_init_vlm_sglang" ) - ctrl = 
GatewayInferenceController(config=config, scheduler=local_scheduler)
+    ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler)
     ctrl.initialize(
         role="rollout-vlm-sglang",
-        server_args={"skip_tokenizer_init": True, "mem_fraction_static": 0.25},
+        server_args={
+            "model_path": vlm_model_path,
+            "skip_tokenizer_init": True,
+            "mem_fraction_static": 0.25,
+        },
     )
 
     try:
@@ -1324,17 +1319,14 @@ def gateway_controller_full_init_vlm_vllm(vlm_model_path, tmp_path_factory):
     if not has_gpu():
         pytest.skip("GPU required")
 
-    from areal.api.cli_args import SchedulingSpec
-    from areal.experimental.inference_service.controller.config import (
-        GatewayControllerConfig,
-    )
+    from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec
     from areal.experimental.inference_service.controller.controller import (
-        GatewayInferenceController,
+        RolloutControllerV2,
     )
 
-    config = GatewayControllerConfig(
+    config = InferenceEngineConfig(
         tokenizer_path=vlm_model_path,
-        model_path=vlm_model_path,
+        model=vlm_model_path,
         backend="vllm:d1",
         scheduling_spec=(
             SchedulingSpec(
@@ -1350,7 +1342,7 @@
     local_scheduler = _make_local_scheduler(
         tmp_path_factory, "gateway_controller_full_init_vlm_vllm"
     )
-    ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler)
+    ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler)
     ctrl.initialize(
         role="rollout-vlm-vllm",
-        server_args={"gpu_memory_utilization": 0.25},
+        server_args={"model": vlm_model_path, "gpu_memory_utilization": 0.25},
diff --git a/tests/experimental/inference_service/test_controller_version.py b/tests/experimental/inference_service/test_controller_version.py
index b2d90c3a5d..255bc6319a 100644
--- a/tests/experimental/inference_service/test_controller_version.py
+++ b/tests/experimental/inference_service/test_controller_version.py
@@ -1,4 +1,4 @@
-"""Unit tests for GatewayInferenceController version management.
+"""Unit tests for RolloutControllerV2 version management.
 
 Tests set_version and get_version with mocked HTTP calls.
 """
@@ -8,12 +8,9 @@
 import asyncio
 from unittest.mock import AsyncMock, MagicMock, patch
 
-from areal.api.cli_args import SchedulingSpec
-from areal.experimental.inference_service.controller.config import (
-    GatewayControllerConfig,
-)
+from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec
 from areal.experimental.inference_service.controller.controller import (
-    GatewayInferenceController,
+    RolloutControllerV2,
 )
 
 # =============================================================================
@@ -25,17 +22,18 @@ def _make_controller(
     gateway_addr: str = "",
     worker_ids: dict[str, str] | None = None,
     version: int = 0,
-) -> GatewayInferenceController:
+) -> RolloutControllerV2:
     """Create a controller with minimal config and manually injected state.
 
     Does NOT call initialize() — internal fields are set directly.
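+
+    A minimal usage sketch (the addresses are illustrative; nothing is
+    contacted over the network because the state is injected directly)::
+
+        ctrl = _make_controller(
+            gateway_addr="http://127.0.0.1:8080",
+            worker_ids={"worker-0": "10.0.0.1:9000"},
+            version=3,
+        )
+        assert ctrl.get_version() == 3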
""" - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-key", scheduling_spec=(SchedulingSpec(),), ) scheduler = MagicMock() - ctrl = GatewayInferenceController(config=cfg, scheduler=scheduler) + ctrl = RolloutControllerV2(config=cfg, scheduler=scheduler) ctrl._gateway_addr = gateway_addr ctrl._worker_ids = worker_ids or {} ctrl._version = version @@ -48,7 +46,7 @@ def _make_controller( class TestControllerSetVersion: - """Test GatewayInferenceController.set_version.""" + """Test RolloutControllerV2.set_version.""" def test_set_version_updates_local(self): ctrl = _make_controller() @@ -93,7 +91,7 @@ def test_set_version_broadcasts_to_all_workers(self): class TestControllerGetVersion: - """Test GatewayInferenceController.get_version.""" + """Test RolloutControllerV2.get_version.""" def test_get_version_returns_local(self): ctrl = _make_controller(version=0) diff --git a/tests/experimental/inference_service/test_examples.py b/tests/experimental/inference_service/test_examples.py index a91c552e56..763e4e5ee7 100644 --- a/tests/experimental/inference_service/test_examples.py +++ b/tests/experimental/inference_service/test_examples.py @@ -79,7 +79,8 @@ def wait_for_pattern(process, pattern, timeout, raise_on_exit=True): f"cluster.name_resolve.nfs_record_root={str(name_resolve_path)}", f"actor.path={model_path}", "scheduler.type=local", - f"rollout.openai.admin_api_key={admin_api_key}", + f"rollout.admin_api_key={admin_api_key}", + "rollout._version=v2", "stats_logger.wandb.mode=disabled", ] @@ -245,6 +246,7 @@ def test_tau2_rollout(tmp_path_factory): "cluster.n_gpus_per_node=2", f"cluster.fileroot={str(experiments_path)}", f"cluster.name_resolve.nfs_record_root={str(name_resolve_path)}", + "rollout.admin_api_key=test-admin-key", f"model_path={model_path}", "train_dataset.batch_size=2", "train_dataset.path=tau2/train", diff --git a/tests/experimental/weight_update/test_disk_integration.py b/tests/experimental/weight_update/test_disk_integration.py index fc4fde69f5..6824ba2894 100644 --- a/tests/experimental/weight_update/test_disk_integration.py +++ b/tests/experimental/weight_update/test_disk_integration.py @@ -360,15 +360,13 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): from areal.api import FinetuneSpec from areal.api.cli_args import ( + InferenceEngineConfig, OptimizerConfig, SchedulingSpec, TrainEngineConfig, ) - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.training_service.controller.controller import ( GatewayTrainController, @@ -385,9 +383,8 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): scheduler = _make_local_scheduler(tmp, "disk_e2e", gpu_devices=list(range(n_gpus))) - inf_config = GatewayControllerConfig( + inf_config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, backend=f"sglang:d{n_half}", scheduling_spec=( SchedulingSpec( @@ -400,7 +397,7 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): setup_timeout=300.0, admin_api_key="test-admin", ) - inf_ctrl = GatewayInferenceController(config=inf_config, scheduler=scheduler) + inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler) train_config = TrainEngineConfig( backend=f"fsdp:d{n_half}", @@ -430,7 +427,7 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): # -- 1. 
SGLang via inference controller ---------------------------- inf_ctrl.initialize( role="rollout", - server_args={"mem_fraction_static": 0.7}, + server_args={"model_path": model_path, "mem_fraction_static": 0.7}, ) inf_worker_urls = list(inf_ctrl._inf_addrs) diff --git a/tests/experimental/weight_update/test_nccl_integration.py b/tests/experimental/weight_update/test_nccl_integration.py index cd6f443a0c..7ceca68423 100644 --- a/tests/experimental/weight_update/test_nccl_integration.py +++ b/tests/experimental/weight_update/test_nccl_integration.py @@ -308,15 +308,13 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): from areal.api import FinetuneSpec from areal.api.cli_args import ( + InferenceEngineConfig, OptimizerConfig, SchedulingSpec, TrainEngineConfig, ) - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.training_service.controller.controller import ( GatewayTrainController, @@ -332,9 +330,8 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): scheduler = _make_local_scheduler(tmp, "e2e", gpu_devices=list(range(n_gpus))) - inf_config = GatewayControllerConfig( + inf_config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, backend=f"sglang:d{n_half}", scheduling_spec=( SchedulingSpec( @@ -347,7 +344,7 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): setup_timeout=300.0, admin_api_key="test-admin", ) - inf_ctrl = GatewayInferenceController(config=inf_config, scheduler=scheduler) + inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler) train_config = TrainEngineConfig( backend=f"fsdp:d{n_half}", @@ -377,7 +374,7 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): # -- 1. 
SGLang via inference controller ---------------------------- inf_ctrl.initialize( role="rollout", - server_args={"mem_fraction_static": 0.7}, + server_args={"model_path": model_path, "mem_fraction_static": 0.7}, ) inf_worker_urls = list(inf_ctrl._inf_addrs) @@ -510,12 +507,14 @@ def _run_megatron_awex_e2e( model_path: str | None = None, ): from areal.api import FinetuneSpec - from areal.api.cli_args import OptimizerConfig, SchedulingSpec, TrainEngineConfig - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, + from areal.api.cli_args import ( + InferenceEngineConfig, + OptimizerConfig, + SchedulingSpec, + TrainEngineConfig, ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.training_service.controller.controller import ( GatewayTrainController, @@ -530,9 +529,8 @@ def _run_megatron_awex_e2e( model_path = model_path or _get_test_model_path() scheduler = _make_local_scheduler(tmp, tag, gpu_devices=list(range(n_gpus))) - inf_config = GatewayControllerConfig( + inf_config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, backend=f"sglang:d{n_infer}", scheduling_spec=( SchedulingSpec( @@ -545,7 +543,7 @@ def _run_megatron_awex_e2e( setup_timeout=300.0, admin_api_key="test-admin", ) - inf_ctrl = GatewayInferenceController(config=inf_config, scheduler=scheduler) + inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler) train_config = TrainEngineConfig( backend=backend, @@ -571,7 +569,10 @@ def _run_megatron_awex_e2e( wu_ctrl: WeightUpdateController | None = None try: - inf_ctrl.initialize(role="rollout", server_args={"mem_fraction_static": 0.7}) + inf_ctrl.initialize( + role="rollout", + server_args={"model_path": model_path, "mem_fraction_static": 0.7}, + ) inf_worker_urls = list(inf_ctrl._inf_addrs) for url in inf_worker_urls: From bab951a3a8816280ea6c1b8a96639b5a29f6171c Mon Sep 17 00:00:00 2001 From: nuzant Date: Sat, 25 Apr 2026 17:49:27 +0000 Subject: [PATCH 02/12] chore: format --- areal/api/cli_args.py | 28 +++++-- .../controller/controller.py | 16 +--- docs/en/cli_reference.md | 76 +++++++++++++------ docs/zh/cli_reference.md | 76 +++++++++++++------ .../inference_service/test_controller.py | 6 +- 5 files changed, 133 insertions(+), 69 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 2a7f76801f..fac07bc5ac 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1983,15 +1983,21 @@ class AgentConfig: ) setup_timeout: float = field( default=120.0, - metadata={"help": "Timeout in seconds waiting for each service to become healthy."}, + metadata={ + "help": "Timeout in seconds waiting for each service to become healthy." + }, ) health_poll_interval: float = field( default=5.0, - metadata={"help": "Seconds between pair health polls; 0 disables health monitoring."}, + metadata={ + "help": "Seconds between pair health polls; 0 disables health monitoring." + }, ) drain_timeout: float = field( default=30.0, - metadata={"help": "Seconds to wait for active sessions to drain before force-killing a pair."}, + metadata={ + "help": "Seconds to wait for active sessions to drain before force-killing a pair." 
+ }, ) log_level: str = field( default="info", @@ -1999,7 +2005,9 @@ class AgentConfig: ) env: dict[str, str] = field( default_factory=dict, - metadata={"help": "Extra environment variables passed to all forked child processes."}, + metadata={ + "help": "Extra environment variables passed to all forked child processes." + }, ) def __post_init__(self) -> None: @@ -2153,7 +2161,9 @@ class InferenceEngineConfig: ) poll_interval: float = field( default=5.0, - metadata={"help": "Health-poll interval in seconds for the inference-service router."}, + metadata={ + "help": "Health-poll interval in seconds for the inference-service router." + }, ) set_reward_finish_timeout: float = field( default=0.0, @@ -2173,7 +2183,9 @@ class InferenceEngineConfig: ) api_url: str | None = field( default=None, - metadata={"help": "External OpenAI-compatible base URL for inference-service external model mode."}, + metadata={ + "help": "External OpenAI-compatible base URL for inference-service external model mode." + }, ) provider_api_key: str | None = field( default=None, @@ -2181,7 +2193,9 @@ class InferenceEngineConfig: ) n_gpus_per_node: int | None = field( default=None, - metadata={"help": "GPUs per physical node for multinode inference-service launch."}, + metadata={ + "help": "GPUs per physical node for multinode inference-service launch." + }, ) def __post_init__(self): diff --git a/areal/experimental/inference_service/controller/controller.py b/areal/experimental/inference_service/controller/controller.py index 57f511c61d..45de0ea17d 100644 --- a/areal/experimental/inference_service/controller/controller.py +++ b/areal/experimental/inference_service/controller/controller.py @@ -963,9 +963,7 @@ def get_version(self) -> int: def get_capacity(self) -> int: if self.staleness_manager is None: - raise RuntimeError( - "RolloutControllerV2.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") return self.staleness_manager.get_capacity() # -- Submit / Wait / Batch --------------------------------------------- @@ -1056,9 +1054,7 @@ def rollout_batch( A list of trajectory dicts (one per completed rollout). """ if not self._gateway_addr: - raise RuntimeError( - "RolloutControllerV2.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") if data is None: if batch_size is None: raise ValueError( @@ -1127,9 +1123,7 @@ def prepare_batch( A list of trajectory dicts (matching ``RolloutController`` API). 
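+
+        A hedged usage sketch (``train_loader`` stands in for any iterable
+        dataloader; the controller must be initialized first)::
+
+            batch = controller.prepare_batch(dataloader=train_loader)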
""" if not self._gateway_addr: - raise RuntimeError( - "RolloutControllerV2.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") if dataloader is None: if batch_size is None: raise ValueError( @@ -1333,9 +1327,7 @@ def staleness_manager(self): @property def workflow_executor(self): if self._workflow_executor is None: - raise RuntimeError( - "RolloutControllerV2.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") return self._workflow_executor @property diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 8ea5263c78..abec4059ab 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -71,6 +71,7 @@ For detailed examples, see the experiment configurations in the `examples/` dire ### Others +- [Agent Configuration](section-agent) - [ArchonEngine Configuration](section-archon-engine) - [ArchonFP8 Configuration](section-archon-fp8) - [DPO Configuration](section-dpo) @@ -527,30 +528,40 @@ Controls text generation behavior for rollout. Configuration for inference servers, including offpolicyness control. -| Parameter | Type | Default | Description | -| ------------------------- | ---------------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `experiment_name` | string \| None | `None` | - | -| `trial_name` | string \| None | `None` | - | -| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | -| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | -| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | -| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | -| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | -| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | -| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | -| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | -| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | -| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | -| `request_timeout` | float | `3600` | Timeout for HTTP requests. | -| `request_retries` | integer | `3` | Number of retries for failed requests. | -| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | -| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. 
Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | -| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | -| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| Parameter | Type | Default | Description | +| --------------------------- | ---------------------------------------------------- | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `experiment_name` | string \| None | `None` | - | +| `trial_name` | string \| None | `None` | - | +| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | +| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | +| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | +| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | +| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | +| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | +| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | +| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | +| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | +| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | +| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | +| `request_timeout` | float | `3600` | Timeout for HTTP requests. | +| `request_retries` | integer | `3` | Number of retries for failed requests. | +| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. 
| +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | +| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | +| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| `_version` | string | `"v1"` | Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2. **Choices:** `v1`, `v2` | +| `model` | string | `"default"` | Model name exposed through the inference-service gateway. | +| `routing_strategy` | string | `"round_robin"` | Routing strategy for the inference-service router. | +| `poll_interval` | float | `5.0` | Health-poll interval in seconds for the inference-service router. | +| `set_reward_finish_timeout` | float | `0.0` | Timeout in seconds to wait for additional reward updates before finalizing a session. | +| `log_level` | string | `"info"` | Log level for inference-service micro-services. | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by the inference-service gateway, router, and data proxies. | +| `api_url` | string \| None | `None` | External OpenAI-compatible base URL for inference-service external model mode. | +| `provider_api_key` | string \| None | `None` | API key for the external OpenAI-compatible provider. | +| `n_gpus_per_node` | integer \| None | `None` | GPUs per physical node for multinode inference-service launch. | (section-sg-lang)= @@ -856,6 +867,23 @@ Configuration for Weights & Biases experiment tracking. | `config` | `dict` \| None | `None` | - | | `id_suffix` | string \| None | `"train"` | - | +(section-agent)= + +## Agent Configuration + +Configuration for the experimental agent service controller. + +| Parameter | Type | Default | Description | +| ---------------------- | ------- | --------------------- | ------------------------------------------------------------------------- | +| `agent_cls_path` | string | `""` | Fully-qualified import path for the AgentRunnable implementation. | +| `admin_api_key` | string | `"areal-agent-admin"` | Shared admin API key for agent-service inter-service auth. | +| `num_pairs` | integer | `1` | Number of Worker+DataProxy pairs to launch on initialize. | +| `setup_timeout` | float | `120.0` | Timeout in seconds waiting for each service to become healthy. | +| `health_poll_interval` | float | `5.0` | Seconds between pair health polls; 0 disables health monitoring. | +| `drain_timeout` | float | `30.0` | Seconds to wait for active sessions to drain before force-killing a pair. | +| `log_level` | string | `"info"` | Log level for spawned agent-service micro-services. | +| `env` | `dict` | **Required** | Extra environment variables passed to all forked child processes. 
| + (section-archon-engine)= ## ArchonEngine Configuration diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index b1ca97aa3e..ebff40ba84 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -69,6 +69,7 @@ python3 train.py --config path/to/config.yaml actor.lr=1e-4 seed=42 ### Others +- [Agent Configuration](section-agent) - [ArchonEngine Configuration](section-archon-engine) - [ArchonFP8 Configuration](section-archon-fp8) - [DPO Configuration](section-dpo) @@ -525,30 +526,40 @@ Controls text generation behavior for rollout. Configuration for inference servers, including offpolicyness control. -| Parameter | Type | Default | Description | -| ------------------------- | ---------------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `experiment_name` | string \| None | `None` | - | -| `trial_name` | string \| None | `None` | - | -| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | -| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | -| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | -| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | -| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | -| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | -| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | -| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | -| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | -| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | -| `request_timeout` | float | `3600` | Timeout for HTTP requests. | -| `request_retries` | integer | `3` | Number of retries for failed requests. | -| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | -| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. 
| -| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | -| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | -| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| Parameter | Type | Default | Description | +| --------------------------- | ---------------------------------------------------- | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `experiment_name` | string \| None | `None` | - | +| `trial_name` | string \| None | `None` | - | +| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | +| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | +| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | +| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | +| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | +| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | +| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | +| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | +| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | +| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | +| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | +| `request_timeout` | float | `3600` | Timeout for HTTP requests. | +| `request_retries` | integer | `3` | Number of retries for failed requests. | +| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | +| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). 
| +| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| `_version` | string | `"v1"` | Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2. **Choices:** `v1`, `v2` | +| `model` | string | `"default"` | Model name exposed through the inference-service gateway. | +| `routing_strategy` | string | `"round_robin"` | Routing strategy for the inference-service router. | +| `poll_interval` | float | `5.0` | Health-poll interval in seconds for the inference-service router. | +| `set_reward_finish_timeout` | float | `0.0` | Timeout in seconds to wait for additional reward updates before finalizing a session. | +| `log_level` | string | `"info"` | Log level for inference-service micro-services. | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by the inference-service gateway, router, and data proxies. | +| `api_url` | string \| None | `None` | External OpenAI-compatible base URL for inference-service external model mode. | +| `provider_api_key` | string \| None | `None` | API key for the external OpenAI-compatible provider. | +| `n_gpus_per_node` | integer \| None | `None` | GPUs per physical node for multinode inference-service launch. | (section-sg-lang)= @@ -854,6 +865,23 @@ Configuration for Weights & Biases experiment tracking. | `config` | `dict` \| None | `None` | - | | `id_suffix` | string \| None | `"train"` | - | +(section-agent)= + +## Agent Configuration + +Configuration for the experimental agent service controller. + +| Parameter | Type | Default | Description | +| ---------------------- | ------- | --------------------- | ------------------------------------------------------------------------- | +| `agent_cls_path` | string | `""` | Fully-qualified import path for the AgentRunnable implementation. | +| `admin_api_key` | string | `"areal-agent-admin"` | Shared admin API key for agent-service inter-service auth. | +| `num_pairs` | integer | `1` | Number of Worker+DataProxy pairs to launch on initialize. | +| `setup_timeout` | float | `120.0` | Timeout in seconds waiting for each service to become healthy. | +| `health_poll_interval` | float | `5.0` | Seconds between pair health polls; 0 disables health monitoring. | +| `drain_timeout` | float | `30.0` | Seconds to wait for active sessions to drain before force-killing a pair. | +| `log_level` | string | `"info"` | Log level for spawned agent-service micro-services. | +| `env` | `dict` | **Required** | Extra environment variables passed to all forked child processes. 
| + (section-archon-engine)= ## ArchonEngine Configuration diff --git a/tests/experimental/inference_service/test_controller.py b/tests/experimental/inference_service/test_controller.py index 95a929a7e9..25b0449c58 100644 --- a/tests/experimental/inference_service/test_controller.py +++ b/tests/experimental/inference_service/test_controller.py @@ -8,7 +8,7 @@ import httpx import pytest -from areal.api.cli_args import InferenceEngineConfig, OpenAIProxyConfig +from areal.api.cli_args import InferenceEngineConfig from areal.experimental.inference_service.controller.controller import ( RolloutControllerV2, ) @@ -252,7 +252,9 @@ def test_admin_api_key_none_raises(self): RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_model_empty_raises(self): - cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key", model="") + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-key", model="" + ) with pytest.raises(ValueError, match="model must not be empty"): RolloutControllerV2(config=cfg, scheduler=MagicMock()) From b279abf6f792adda6de3fad5a4eda05a02c27296 Mon Sep 17 00:00:00 2001 From: nuzant Date: Sun, 26 Apr 2026 03:18:33 +0000 Subject: [PATCH 03/12] feat(service): add runtime inference APIs to agent controller Enable AgentController to manage inference-backed sessions and export trajectories for agent-service workflows. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- .../agent_service/controller/controller.py | 180 ++++++++++++++++++ .../agent_service/test_controller.py | 139 +++++++++++++- 2 files changed, 317 insertions(+), 2 deletions(-) diff --git a/areal/experimental/agent_service/controller/controller.py b/areal/experimental/agent_service/controller/controller.py index e7465e52be..633034c38b 100644 --- a/areal/experimental/agent_service/controller/controller.py +++ b/areal/experimental/agent_service/controller/controller.py @@ -26,24 +26,30 @@ import threading import time import traceback +import uuid from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from typing import TYPE_CHECKING, Any +import aiohttp import requests from areal.api.cli_args import AgentConfig +from areal.experimental.openai.proxy.server import deserialize_interactions from areal.utils import logging from areal.utils.network import format_hostport if TYPE_CHECKING: from areal.api.scheduler_api import Scheduler, Worker + from areal.experimental.openai.types import InteractionWithTokenLogpReward logger = logging.getLogger("AgentController") _GUARD_ROLE = "agent-guard" _UNREGISTER_RETRIES = 3 _HEALTH_CHECK_WORKERS = 4 +_DEFAULT_RUNTIME_TIMEOUT = 600.0 +_DEFAULT_INFERENCE_ADMIN_API_KEY = "areal-admin-key" @dataclass @@ -58,6 +64,16 @@ class _WorkerPair: worker_addr: str +@dataclass +class _RuntimeSession: + agent_session_id: str + inference_gateway_addr: str + inference_admin_api_key: str + inference_session_id: str + inference_session_api_key: str + inference_model: str = "" + + class AgentController: """Orchestrator for the Agent Service micro-service stack. 
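+
+    A hedged sketch of the runtime APIs added in this patch (awaited from
+    async code; the gateway address is illustrative, and ``step`` requires
+    the agent-service gateway to be running)::
+
+        info = await ctrl.start_session(
+            "task-1", inference_gateway_addr="http://127.0.0.1:8000"
+        )
+        await ctrl.step("hello", info["session_id"])
+        await ctrl.set_reward(1.0, info["session_id"])
+        traj = await ctrl.export_trajectory(info["session_id"])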
@@ -89,6 +105,8 @@ def __init__( self._next_pair_index: int = 0 self._forked_services: list[tuple[str, str, int]] = [] + self._sessions: dict[str, _RuntimeSession] = {} + self._sessions_lock = threading.Lock() self._health_stop = threading.Event() self._health_thread: threading.Thread | None = None @@ -217,6 +235,8 @@ def destroy(self) -> None: self._guard_addrs.clear() with self._pairs_lock: self._pairs.clear() + with self._sessions_lock: + self._sessions.clear() self._router_addr = "" self._gateway_addr = "" @@ -369,6 +389,166 @@ def pairs(self) -> dict[int, _WorkerPair]: with self._pairs_lock: return dict(self._pairs) + # ------------------------------------------------------------------ + # Runtime APIs + # ------------------------------------------------------------------ + + @staticmethod + def _bearer_headers(api_key: str) -> dict[str, str]: + return {"Authorization": f"Bearer {api_key}"} + + async def _post_json( + self, + url: str, + payload: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + *, + timeout: float = _DEFAULT_RUNTIME_TIMEOUT, + expect_json: bool = True, + ) -> dict[str, Any]: + client_timeout = aiohttp.ClientTimeout(total=timeout) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + async with session.post(url, json=payload, headers=headers) as resp: + resp.raise_for_status() + if not expect_json: + return {} + return await resp.json() + + async def _grant_capacity( + self, + inference_gateway_addr: str, + inference_admin_api_key: str, + ) -> None: + await self._post_json( + f"{inference_gateway_addr.rstrip('/')}/grant_capacity", + headers=self._bearer_headers(inference_admin_api_key), + expect_json=False, + ) + + async def start_session( + self, + task_id: str, + *, + inference_gateway_addr: str, + inference_admin_api_key: str = _DEFAULT_INFERENCE_ADMIN_API_KEY, + inference_model: str = "", + api_key: str | None = None, + ) -> dict[str, str]: + agent_session_id = f"agent-sess-{uuid.uuid4().hex[:12]}" + normalized_task_id = task_id or agent_session_id + gateway_addr = inference_gateway_addr.rstrip("/") + + await self._grant_capacity(gateway_addr, inference_admin_api_key) + + payload: dict[str, Any] = {"task_id": normalized_task_id} + if api_key is not None: + payload["api_key"] = api_key + + data = await self._post_json( + f"{gateway_addr}/rl/start_session", + payload=payload, + headers=self._bearer_headers(inference_admin_api_key), + ) + + session = _RuntimeSession( + agent_session_id=agent_session_id, + inference_gateway_addr=gateway_addr, + inference_admin_api_key=inference_admin_api_key, + inference_session_id=data["session_id"], + inference_session_api_key=data["api_key"], + inference_model=inference_model, + ) + with self._sessions_lock: + self._sessions[agent_session_id] = session + + return { + "session_id": agent_session_id, + "inference_session_id": session.inference_session_id, + "api_key": session.inference_session_api_key, + } + + async def step( + self, + input: str | list[dict[str, Any]], + session_id: str, + *, + metadata: dict[str, Any] | None = None, + ) -> dict[str, Any]: + if not self._gateway_addr: + raise RuntimeError("step() requires the agent-service gateway to be running") + + session = self._resolve_session(session_id) + input_items = ( + [{"type": "message", "content": input}] + if isinstance(input, str) + else input + ) + + merged_metadata: dict[str, Any] = { + "inference_base_url": session.inference_gateway_addr, + "inference_api_key": session.inference_session_api_key, + } + if 
session.inference_model: + merged_metadata["inference_model"] = session.inference_model + if metadata: + merged_metadata.update(metadata) + + body: dict[str, Any] = { + "input": input_items, + "model": (session.inference_model or "default").replace("/", "--"), + "user": session.agent_session_id, + } + if merged_metadata: + body["metadata"] = merged_metadata + + return await self._post_json( + f"{self._gateway_addr}/v1/responses", + payload=body, + headers=self._bearer_headers(self.config.admin_api_key), + ) + + async def set_reward( + self, + reward: float, + session_id: str, + *, + interaction_id: str | None = None, + ) -> dict[str, Any]: + session = self._resolve_session(session_id) + return await self._post_json( + f"{session.inference_gateway_addr}/rl/set_reward", + payload={"interaction_id": interaction_id, "reward": reward}, + headers=self._bearer_headers(session.inference_session_api_key), + ) + + async def export_trajectory( + self, + session_id: str, + *, + trajectory_id: int | None = None, + discount: float = 1.0, + style: str = "individual", + ) -> dict[str, InteractionWithTokenLogpReward]: + session = self._resolve_session(session_id) + data = await self._post_json( + f"{session.inference_gateway_addr}/export_trajectories", + payload={ + "session_id": session.inference_session_id, + "trajectory_id": trajectory_id, + "discount": discount, + "style": style, + }, + headers=self._bearer_headers(session.inference_admin_api_key), + ) + return deserialize_interactions(data["interactions"]) + + def _resolve_session(self, session_id: str) -> _RuntimeSession: + with self._sessions_lock: + session = self._sessions.get(session_id) + if session is None: + raise KeyError(f"Unknown session_id: {session_id!r}") + return session + # ------------------------------------------------------------------ # Guard interaction helpers # ------------------------------------------------------------------ diff --git a/tests/experimental/agent_service/test_controller.py b/tests/experimental/agent_service/test_controller.py index bacbd1bbe9..8a36a24540 100644 --- a/tests/experimental/agent_service/test_controller.py +++ b/tests/experimental/agent_service/test_controller.py @@ -10,12 +10,15 @@ import time from dataclasses import dataclass -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from areal.api.cli_args import AgentConfig -from areal.experimental.agent_service.controller.controller import AgentController +from areal.experimental.agent_service.controller.controller import ( + AgentController, + _RuntimeSession, +) CTRL = "areal.experimental.agent_service.controller.controller" @@ -334,3 +337,135 @@ def test_health_monitor_disabled_when_interval_zero(self, mock_requests, config) assert ctrl._health_thread is None ctrl.destroy() + + +class TestRuntimeAPIs: + @pytest.mark.asyncio + async def test_start_session_grants_capacity_and_stores_session(self, config): + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentController(config=config, scheduler=scheduler) + ctrl._grant_capacity = AsyncMock() + ctrl._post_json = AsyncMock( + return_value={"session_id": "inf-sess-1", "api_key": "sess-key"} + ) + + session = await ctrl.start_session( + "task-1", + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_model="Qwen/Test", + ) + + assert session["session_id"].startswith("agent-sess-") + assert session["inference_session_id"] == "inf-sess-1" + assert session["api_key"] == "sess-key" + 
ctrl._grant_capacity.assert_awaited_once_with( + "http://inference", "rollout-admin" + ) + ctrl._post_json.assert_awaited_once_with( + "http://inference/rl/start_session", + payload={"task_id": "task-1"}, + headers={"Authorization": "Bearer rollout-admin"}, + ) + + stored = ctrl._resolve_session(session["session_id"]) + assert stored.inference_session_id == "inf-sess-1" + assert stored.inference_session_api_key == "sess-key" + assert stored.inference_model == "Qwen/Test" + + @pytest.mark.asyncio + async def test_step_posts_async_gateway_request(self, config): + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentController(config=config, scheduler=scheduler) + ctrl._gateway_addr = "http://agent-gateway" + ctrl._post_json = AsyncMock(return_value={"status": "completed"}) + ctrl._sessions["agent-sess-1"] = _RuntimeSession( + agent_session_id="agent-sess-1", + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_session_id="inf-sess-1", + inference_session_api_key="sess-key", + inference_model="Qwen/Test", + ) + + result = await ctrl.step( + "hello", + "agent-sess-1", + metadata={"extra": "value"}, + ) + + assert result == {"status": "completed"} + ctrl._post_json.assert_awaited_once_with( + "http://agent-gateway/v1/responses", + payload={ + "input": [{"type": "message", "content": "hello"}], + "model": "Qwen--Test", + "user": "agent-sess-1", + "metadata": { + "inference_base_url": "http://inference", + "inference_api_key": "sess-key", + "inference_model": "Qwen/Test", + "extra": "value", + }, + }, + headers={"Authorization": "Bearer test-key"}, + ) + + @pytest.mark.asyncio + async def test_set_reward_uses_session_api_key(self, config): + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentController(config=config, scheduler=scheduler) + ctrl._post_json = AsyncMock(return_value={"trajectory_id": 7}) + ctrl._sessions["agent-sess-1"] = _RuntimeSession( + agent_session_id="agent-sess-1", + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_session_id="inf-sess-1", + inference_session_api_key="sess-key", + ) + + result = await ctrl.set_reward(1.0, "agent-sess-1", interaction_id="resp-1") + + assert result == {"trajectory_id": 7} + ctrl._post_json.assert_awaited_once_with( + "http://inference/rl/set_reward", + payload={"interaction_id": "resp-1", "reward": 1.0}, + headers={"Authorization": "Bearer sess-key"}, + ) + + @pytest.mark.asyncio + @patch(f"{CTRL}.deserialize_interactions") + async def test_export_trajectory_deserializes_response( + self, mock_deserialize, config + ): + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentController(config=config, scheduler=scheduler) + ctrl._post_json = AsyncMock(return_value={"interactions": {"k": "v"}}) + ctrl._sessions["agent-sess-1"] = _RuntimeSession( + agent_session_id="agent-sess-1", + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_session_id="inf-sess-1", + inference_session_api_key="sess-key", + ) + mock_deserialize.return_value = {"interaction-1": MagicMock(reward=1.0)} + + result = await ctrl.export_trajectory( + "agent-sess-1", + trajectory_id=5, + discount=0.9, + style="individual", + ) + + assert "interaction-1" in result + ctrl._post_json.assert_awaited_once_with( + "http://inference/export_trajectories", + payload={ + "session_id": "inf-sess-1", + "trajectory_id": 5, + "discount": 0.9, + "style": "individual", + }, + headers={"Authorization": "Bearer 
rollout-admin"}, + ) + mock_deserialize.assert_called_once_with({"k": "v"}) From aaa07f48f164c634731e99d82dace6757c9a9c6e Mon Sep 17 00:00:00 2001 From: nuzant Date: Sun, 26 Apr 2026 03:18:43 +0000 Subject: [PATCH 04/12] feat(examples): add experimental claude agent service package Move the Claude agent-service example into the experimental examples package so it can be imported from the new nested module path. Key changes: - add experimental package markers - add claude worker implementation - add claude interactive launcher Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- examples/experimental/__init__.py | 1 + .../experimental/agent_service/__init__.py | 1 + .../agent_service/claude/__init__.py | 1 + .../agent_service/claude/agent.py | 120 ++++++++++++++++++ .../agent_service/claude/run_agent_service.py | 110 ++++++++++++++++ 5 files changed, 233 insertions(+) create mode 100644 examples/experimental/__init__.py create mode 100644 examples/experimental/agent_service/__init__.py create mode 100644 examples/experimental/agent_service/claude/__init__.py create mode 100644 examples/experimental/agent_service/claude/agent.py create mode 100644 examples/experimental/agent_service/claude/run_agent_service.py diff --git a/examples/experimental/__init__.py b/examples/experimental/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/examples/experimental/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/examples/experimental/agent_service/__init__.py b/examples/experimental/agent_service/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/examples/experimental/agent_service/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/examples/experimental/agent_service/claude/__init__.py b/examples/experimental/agent_service/claude/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/examples/experimental/agent_service/claude/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/examples/experimental/agent_service/claude/agent.py b/examples/experimental/agent_service/claude/agent.py new file mode 100644 index 0000000000..2b3a3d01a3 --- /dev/null +++ b/examples/experimental/agent_service/claude/agent.py @@ -0,0 +1,120 @@ +"""Claude Agent for AReaL Agent Service. + +Implements :class:`AgentRunnable` using the Claude Agent SDK +(``claude-agent-sdk``). Each Worker instance holds a pool of +:class:`ClaudeSDKClient` sessions keyed by ``session_key``, so multi-turn +conversations preserve full context without re-sending history. 
+""" + +from __future__ import annotations + +import os +from typing import Any, Literal + +from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + ClaudeSDKClient, + ResultMessage, + TextBlock, + ToolUseBlock, +) + +from areal.experimental.agent_service.types import ( + AgentRequest, + AgentResponse, + EventEmitter, +) +from areal.utils import logging + +logger = logging.getLogger("ClaudeAgent") + +PermissionMode = Literal["default", "acceptEdits", "plan", "bypassPermissions"] + +_DEFAULT_PERMISSION_MODE: PermissionMode = "bypassPermissions" + + +class ClaudeAgent: + """AgentRunnable backed by the Claude Agent SDK.""" + + def __init__(self, **kwargs: Any) -> None: + del kwargs + self._model = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-6") + self._system_prompt = os.environ.get("CLAUDE_SYSTEM_PROMPT", "") + self._max_turns = int(os.environ.get("CLAUDE_MAX_TURNS", "20")) + self._permission_mode: PermissionMode = _DEFAULT_PERMISSION_MODE + self._sessions: dict[str, ClaudeSDKClient] = {} + + logger.info( + "ClaudeAgent initialized (model=%s, max_turns=%d)", + self._model, + self._max_turns, + ) + + def _make_options(self) -> ClaudeAgentOptions: + opts = ClaudeAgentOptions( + model=self._model, + max_turns=self._max_turns, + permission_mode=self._permission_mode, + ) + if self._system_prompt: + opts.system_prompt = self._system_prompt + return opts + + async def _get_or_create_client(self, session_key: str) -> ClaudeSDKClient: + if session_key not in self._sessions: + client = ClaudeSDKClient(options=self._make_options()) + await client.__aenter__() + self._sessions[session_key] = client + logger.info("New session: %s", session_key) + return self._sessions[session_key] + + async def close_session(self, session_key: str) -> None: + client = self._sessions.pop(session_key, None) + if client is not None: + try: + await client.__aexit__(None, None, None) + except Exception: + logger.warning("Error closing session %s", session_key, exc_info=True) + + async def close_all_sessions(self) -> None: + for key in list(self._sessions): + await self.close_session(key) + + async def run( + self, + request: AgentRequest, + *, + emitter: EventEmitter, + ) -> AgentResponse: + client = await self._get_or_create_client(request.session_key) + + try: + await client.query(request.message) + + text_parts: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + await emitter.emit_delta(block.text) + text_parts.append(block.text) + elif isinstance(block, ToolUseBlock): + await emitter.emit_tool_call( + name=block.name, + args=str(block.input), + ) + tool_calls.append( + {"name": block.name, "input": block.input} + ) + elif isinstance(msg, ResultMessage): + break + + return AgentResponse( + summary="".join(text_parts)[:200], + metadata={"tool_calls": tool_calls}, + ) + except Exception: + await self.close_session(request.session_key) + raise diff --git a/examples/experimental/agent_service/claude/run_agent_service.py b/examples/experimental/agent_service/claude/run_agent_service.py new file mode 100644 index 0000000000..77aa7db577 --- /dev/null +++ b/examples/experimental/agent_service/claude/run_agent_service.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Launch the Agent Service with Claude Agent SDK.""" + +from __future__ import annotations + +import argparse +import asyncio +import time + +import httpx + +from 
areal.api.cli_args import AgentConfig +from areal.experimental.agent_service.controller import AgentController + + +async def _wait_healthy(url: str, timeout: float = 60.0) -> None: + async with httpx.AsyncClient() as client: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + resp = await client.get(url) + if resp.status_code == 200: + return + except httpx.ConnectError: + pass + await asyncio.sleep(0.5) + raise TimeoutError(f"Service at {url} did not become healthy") + + +async def interactive_loop(gateway_addr: str, admin_key: str) -> None: + session_key = f"session-{int(time.time())}" + print("Type your message (or 'quit' to exit):\n") + + async with httpx.AsyncClient(timeout=120.0) as client: + while True: + try: + user_input = input("You: ") + except (EOFError, KeyboardInterrupt): + break + if user_input.strip().lower() in {"quit", "exit", "q"}: + break + if not user_input.strip(): + continue + + resp = await client.post( + f"{gateway_addr}/v1/responses", + json={ + "input": [{"type": "message", "content": user_input}], + "model": "claude-agent", + "user": session_key, + }, + headers={"Authorization": f"Bearer {admin_key}"}, + ) + data = resp.json() + + if data.get("status") == "completed": + for item in data.get("output", []): + if item.get("type") == "message": + for block in item.get("content", []): + if block.get("type") == "output_text": + print(f"Agent: {block['text']}") + elif item.get("type") == "function_call": + print(f"[tool] {item.get('name', '')}") + print() + elif data.get("error"): + print(f"Error: {data['error'].get('message', '')[:200]}\n") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Agent Service — Claude Agent SDK") + parser.add_argument("--num-pairs", type=int, default=1) + parser.add_argument("--admin-api-key", default="areal-agent-admin") + args = parser.parse_args() + + from areal.infra.scheduler.local import LocalScheduler + + scheduler = LocalScheduler( + experiment_name="agent-service-demo", + trial_name="run0", + gpu_devices=[], + ) + + ctrl = AgentController( + config=AgentConfig( + agent_cls_path="examples.experimental.agent_service.claude.agent.ClaudeAgent", + admin_api_key=args.admin_api_key, + num_pairs=args.num_pairs, + ), + scheduler=scheduler, + ) + + try: + print(f"Initializing with {args.num_pairs} pair(s) ...") + ctrl.initialize() + print(f" Router: {ctrl.router_addr}") + print(f" Gateway: {ctrl.gateway_addr}") + print(f" Pairs: {len(ctrl.pairs)}") + + asyncio.run(_wait_healthy(f"{ctrl.gateway_addr}/health")) + print("All services ready.\n") + asyncio.run(interactive_loop(ctrl.gateway_addr, admin_key=args.admin_api_key)) + finally: + print("\nShutting down ...") + ctrl.destroy() + print("Done.") + + +if __name__ == "__main__": + main() From fe41c89fed910bb1319971aa290c8d0ca9b5eefe Mon Sep 17 00:00:00 2001 From: nuzant Date: Sun, 26 Apr 2026 03:18:50 +0000 Subject: [PATCH 05/12] refactor(examples): move claude agent service entrypoints Retire the old examples/agent_service layout and point the worker CLI at the new experimental Claude module path. 
Key changes: - update worker example import path - remove legacy claude example files Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- .../agent_service/worker/__main__.py | 2 +- examples/agent_service/README.md | 114 -------------- examples/agent_service/agent.py | 139 ------------------ examples/agent_service/run_agent_service.py | 129 ---------------- 4 files changed, 1 insertion(+), 383 deletions(-) delete mode 100644 examples/agent_service/README.md delete mode 100644 examples/agent_service/agent.py delete mode 100644 examples/agent_service/run_agent_service.py diff --git a/areal/experimental/agent_service/worker/__main__.py b/areal/experimental/agent_service/worker/__main__.py index c14d52eba7..8a2ec97654 100644 --- a/areal/experimental/agent_service/worker/__main__.py +++ b/areal/experimental/agent_service/worker/__main__.py @@ -6,7 +6,7 @@ via Guard to create Worker+DataProxy pairs. python -m areal.experimental.agent_service.worker \ - --agent examples.agent_service.agent.ClaudeAgent \ + --agent examples.experimental.agent_service.claude.agent.ClaudeAgent \ --host 127.0.0.1 --port 9000 """ diff --git a/examples/agent_service/README.md b/examples/agent_service/README.md deleted file mode 100644 index 1c7e8bc839..0000000000 --- a/examples/agent_service/README.md +++ /dev/null @@ -1,114 +0,0 @@ -# Agent Service — Claude Agent SDK - -## Overview - -This example demonstrates AReaL's Agent Service running the **Claude Agent SDK** -(`claude-agent-sdk`) as a scalable HTTP micro-service. It turns Claude's autonomous -agent capabilities — multi-turn conversations, tool use, file editing, web search — into -a production-deployable service with session management, load balancing, and dynamic -scaling. - -**Why this matters**: Projects like -[claude-agent-acp](https://github.com/agentclientprotocol/claude-agent-acp) expose -Claude Agent SDK via custom protocols (ACP) for editor integration. AReaL takes a -different approach — it wraps Claude Agent SDK into standard HTTP micro-services with -session-affine routing, so you can **scale, orchestrate, and train** Claude agents using -AReaL's RL infrastructure. - -``` -Client → Gateway (HTTP) → Router → DataProxy (session state) → Worker (ClaudeSDKClient) -``` - -## Prerequisites - -```bash -uv pip install claude-agent-sdk -export ANTHROPIC_API_KEY=sk-... -``` - -## Quick Start - -```bash -python examples/agent_service/run_agent_service.py -``` - -The script creates a `LocalScheduler`, launches Guard workers, then forks Router → -Worker+DataProxy → Gateway. An interactive prompt lets you chat with the Claude agent. 
- -### Options - -```bash -python examples/agent_service/run_agent_service.py --num-pairs 4 -``` - -### Send requests directly - -```bash -curl -X POST http://localhost:/v1/responses \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer areal-agent-admin" \ - -d '{ - "input": [{"type": "message", "content": "Explain RLHF in simple terms"}], - "model": "claude-agent", - "user": "my-session" - }' -``` - -## Configuration - -Claude Agent SDK settings are controlled via environment variables: - -| Variable | Default | Description | -| ---------------------- | ------------------- | --------------------------- | -| `ANTHROPIC_API_KEY` | (required) | Anthropic API key | -| `CLAUDE_MODEL` | `claude-sonnet-4-6` | Model to use | -| `CLAUDE_SYSTEM_PROMPT` | (none) | Optional system prompt | -| `CLAUDE_MAX_TURNS` | `20` | Max agentic turns per query | - -## Architecture - -The Worker maintains a **session-persistent `ClaudeSDKClient`** per session key. Unlike -stateless wrappers, the SDK's internal session retains the full conversation transcript -— no need to re-send history on each turn. - -``` -Turn 1: Client → Gateway → Router → DataProxy → Worker - Worker: creates ClaudeSDKClient for session "abc" - Claude Agent SDK runs autonomously (tool calls, file ops, etc.) - Response streams back through the chain - -Turn 2: Client → Gateway → Router (same DataProxy) → DataProxy → Worker - Worker: reuses ClaudeSDKClient for session "abc" - SDK remembers full context from Turn 1 -``` - -## Programmatic Usage - -```python -from areal.experimental.agent_service.controller import ( - AgentController, -) -from areal.api.cli_args import AgentConfig -from areal.infra.scheduler.local import LocalScheduler - -scheduler = LocalScheduler(experiment_name="demo", trial_name="run0", gpu_devices=[]) -ctrl = AgentController( - config=AgentConfig( - agent_cls_path="examples.agent_service.agent.ClaudeAgent", - num_pairs=2, - ), - scheduler=scheduler, -) -ctrl.initialize() -# ctrl.gateway_addr → "http://10.0.0.1:9005" -# ctrl.scale_up(2) → add 2 more pairs -# ctrl.scale_down(1) → remove 1 pair (with graceful drain) -ctrl.destroy() -``` - -## Files - -| File | Description | -| ---------------------- | ----------------------------------------------------------- | -| `agent.py` | `ClaudeAgent` — session-persistent Claude Agent SDK wrapper | -| `run_agent_service.py` | Controller-based launcher + interactive conversation | diff --git a/examples/agent_service/agent.py b/examples/agent_service/agent.py deleted file mode 100644 index c05f3bebe5..0000000000 --- a/examples/agent_service/agent.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Claude Agent for AReaL Agent Service. - -Implements :class:`AgentRunnable` using the Claude Agent SDK -(``claude-agent-sdk``). Each Worker instance holds a pool of -:class:`ClaudeSDKClient` sessions keyed by ``session_key``, so multi-turn -conversations preserve full context without re-sending history. 
- -Requires:: - - pip install claude-agent-sdk - -Environment variables: - ANTHROPIC_API_KEY — Anthropic API key (required) - CLAUDE_MODEL — model name (default: claude-sonnet-4-6) - CLAUDE_SYSTEM_PROMPT — optional system prompt override - CLAUDE_MAX_TURNS — max agentic turns per query (default: 20) -""" - -from __future__ import annotations - -import os -from typing import Any, Literal - -from claude_agent_sdk import ( - AssistantMessage, - ClaudeAgentOptions, - ClaudeSDKClient, - ResultMessage, - TextBlock, - ToolUseBlock, -) - -from areal.experimental.agent_service.types import ( - AgentRequest, - AgentResponse, - EventEmitter, -) -from areal.utils import logging - -logger = logging.getLogger("ClaudeAgent") - -PermissionMode = Literal["default", "acceptEdits", "plan", "bypassPermissions"] - -_DEFAULT_PERMISSION_MODE: PermissionMode = "bypassPermissions" - - -class ClaudeAgent: - """AgentRunnable backed by the Claude Agent SDK. - - Maintains a ``ClaudeSDKClient`` per session for true multi-turn - continuity — the SDK's internal session keeps the full transcript, - so ``request.history`` is only used for the very first turn of a - new session (to seed context if provided by the caller). - """ - - def __init__(self, **kwargs: Any) -> None: - self._model = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-6") - self._system_prompt = os.environ.get("CLAUDE_SYSTEM_PROMPT", "") - self._max_turns = int(os.environ.get("CLAUDE_MAX_TURNS", "20")) - self._permission_mode: PermissionMode = _DEFAULT_PERMISSION_MODE - - self._sessions: dict[str, ClaudeSDKClient] = {} - - logger.info( - "ClaudeAgent initialized (model=%s, max_turns=%d)", - self._model, - self._max_turns, - ) - - def _make_options(self) -> ClaudeAgentOptions: - opts = ClaudeAgentOptions( - model=self._model, - max_turns=self._max_turns, - permission_mode=self._permission_mode, - ) - if self._system_prompt: - opts.system_prompt = self._system_prompt - return opts - - async def _get_or_create_client(self, session_key: str) -> ClaudeSDKClient: - if session_key not in self._sessions: - client = ClaudeSDKClient(options=self._make_options()) - await client.__aenter__() - self._sessions[session_key] = client - logger.info("New session: %s", session_key) - return self._sessions[session_key] - - async def close_session(self, session_key: str) -> None: - client = self._sessions.pop(session_key, None) - if client is not None: - try: - await client.__aexit__(None, None, None) - except Exception: - logger.warning("Error closing session %s", session_key, exc_info=True) - - async def close_all_sessions(self) -> None: - keys = list(self._sessions.keys()) - for key in keys: - await self.close_session(key) - - async def run( - self, - request: AgentRequest, - *, - emitter: EventEmitter, - ) -> AgentResponse: - client = await self._get_or_create_client(request.session_key) - - try: - await client.query(request.message) - - text_parts: list[str] = [] - tool_calls: list[dict[str, Any]] = [] - - async for msg in client.receive_response(): - if isinstance(msg, AssistantMessage): - for block in msg.content: - if isinstance(block, TextBlock): - await emitter.emit_delta(block.text) - text_parts.append(block.text) - elif isinstance(block, ToolUseBlock): - await emitter.emit_tool_call( - name=block.name, - args=str(block.input), - ) - tool_calls.append( - {"name": block.name, "input": block.input} - ) - elif isinstance(msg, ResultMessage): - break - - summary = "".join(text_parts) - return AgentResponse( - summary=summary[:200], - metadata={"tool_calls": 
tool_calls}, - ) - except Exception: - await self.close_session(request.session_key) - raise diff --git a/examples/agent_service/run_agent_service.py b/examples/agent_service/run_agent_service.py deleted file mode 100644 index 80850dc4ca..0000000000 --- a/examples/agent_service/run_agent_service.py +++ /dev/null @@ -1,129 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -"""Launch the Agent Service with Claude Agent SDK. - -Usage:: - - python examples/agent_service/run_agent_service.py - python examples/agent_service/run_agent_service.py --num-pairs 2 - -Requires:: - - uv pip install claude-agent-sdk - export ANTHROPIC_API_KEY=sk-... -""" - -from __future__ import annotations - -import argparse -import asyncio -import time - -import httpx - -from areal.api.cli_args import AgentConfig -from areal.experimental.agent_service.controller import AgentController - - -async def _wait_healthy(url: str, timeout: float = 60.0) -> None: - async with httpx.AsyncClient() as client: - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - try: - resp = await client.get(url) - if resp.status_code == 200: - return - except httpx.ConnectError: - pass - await asyncio.sleep(0.5) - raise TimeoutError(f"Service at {url} did not become healthy") - - -async def interactive_loop(gateway_addr: str, admin_key: str) -> None: - session_key = f"session-{int(time.time())}" - print("Type your message (or 'quit' to exit):\n") - - async with httpx.AsyncClient(timeout=120.0) as client: - while True: - try: - user_input = input("You: ") - except (EOFError, KeyboardInterrupt): - break - if user_input.strip().lower() in ("quit", "exit", "q"): - break - if not user_input.strip(): - continue - - resp = await client.post( - f"{gateway_addr}/v1/responses", - json={ - "input": [{"type": "message", "content": user_input}], - "model": "claude-agent", - "user": session_key, - }, - headers={"Authorization": f"Bearer {admin_key}"}, - ) - data = resp.json() - - if data.get("status") == "completed": - for item in data.get("output", []): - if item.get("type") == "message": - for block in item.get("content", []): - if block.get("type") == "output_text": - print(f"Agent: {block['text']}") - elif item.get("type") == "function_call": - print(f"[tool] {item.get('name', '')}") - print() - elif data.get("error"): - print(f"Error: {data['error'].get('message', '')[:200]}\n") - - -def main() -> None: - parser = argparse.ArgumentParser(description="Agent Service — Claude Agent SDK") - parser.add_argument( - "--num-pairs", - type=int, - default=1, - help="Number of Worker+DataProxy pairs (default: 1)", - ) - parser.add_argument( - "--admin-api-key", - default="areal-agent-admin", - help="Admin API key for inter-service auth", - ) - args = parser.parse_args() - - from areal.infra.scheduler.local import LocalScheduler - - scheduler = LocalScheduler( - experiment_name="agent-service-demo", - trial_name="run0", - gpu_devices=[], - ) - - ctrl_config = AgentConfig( - agent_cls_path="examples.agent_service.agent.ClaudeAgent", - admin_api_key=args.admin_api_key, - num_pairs=args.num_pairs, - ) - ctrl = AgentController(config=ctrl_config, scheduler=scheduler) - - try: - print(f"Initializing with {args.num_pairs} pair(s) ...") - ctrl.initialize() - print(f" Router: {ctrl.router_addr}") - print(f" Gateway: {ctrl.gateway_addr}") - print(f" Pairs: {len(ctrl.pairs)}") - - asyncio.run(_wait_healthy(f"{ctrl.gateway_addr}/health")) - print("All services ready.\n") - - asyncio.run(interactive_loop(ctrl.gateway_addr, admin_key=args.admin_api_key)) 
- finally: - print("\nShutting down ...") - ctrl.destroy() - print("Done.") - - -if __name__ == "__main__": - main() From 0397381cfeb292f5c264802adafa38568988229f Mon Sep 17 00:00:00 2001 From: nuzant Date: Sun, 26 Apr 2026 03:18:57 +0000 Subject: [PATCH 06/12] docs(service): document agent service runtime workflows Describe the new AgentController runtime APIs and document the relocated experimental Claude and tau2 examples. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- areal/experimental/agent_service/README.md | 55 ++++++++++- examples/experimental/agent_service/README.md | 91 +++++++++++++++++++ 2 files changed, 142 insertions(+), 4 deletions(-) create mode 100644 examples/experimental/agent_service/README.md diff --git a/areal/experimental/agent_service/README.md b/areal/experimental/agent_service/README.md index 30f76bb8ec..5543de01a9 100644 --- a/areal/experimental/agent_service/README.md +++ b/areal/experimental/agent_service/README.md @@ -4,7 +4,10 @@ The Agent Service provides **agent-level** capabilities on top of AReaL's model-level proxy. It exposes complete agent sessions — multi-turn conversations with tool use, -memory, and pluggable agent frameworks — via independent HTTP microservices. +memory, and pluggable agent frameworks — via independent HTTP microservices. It also +includes an `AgentController` that can launch the stack through Guard processes and +bridge agent conversations to the experimental inference service for RL data +collection. ## Architecture @@ -47,6 +50,10 @@ at startup. Each `POST /run` request is a single turn — the agent receives the conversation history in the request and returns a response. The Worker has no session state. +**AgentController** — Python orchestrator that launches Guards via the scheduler, forks +the Router / Gateway / Worker+DataProxy pairs onto them, supports scale-up and +scale-down, and exposes async runtime APIs for inference-backed RL sessions. + ## Agent Protocol Any class that satisfies the `AgentRunnable` protocol can run on the Worker: @@ -129,6 +136,29 @@ class EventEmitter(Protocol): | `/ws` | WS | Gateway WebSocket protocol | | `/v1/responses` | POST | OpenResponses HTTP bridge | +## AgentController Runtime APIs + +`AgentController` is the integration point used by the examples and rollout workflows. +It manages the agent-service stack and exposes async helpers for RL/inference flows: + +| Method | Description | +| ------ | ----------- | +| `initialize()` | Launch Guards, Router, Worker+DataProxy pairs, Gateway, and the health monitor | +| `destroy()` | Tear down the full stack in reverse order | +| `scale_up(count)` | Add Worker+DataProxy pairs | +| `scale_down(count)` | Unregister, drain, and remove pairs | +| `start_session(...)` | Grant inference capacity and create an RL session bound to an agent session | +| `step(input, session_id, metadata=None)` | Send a turn through the agent-service Gateway `POST /v1/responses` | +| `set_reward(reward, session_id, interaction_id=None)` | Forward the final reward to the inference service | +| `export_trajectory(session_id, ...)` | Export serialized interactions from the inference service | + +Typical rollout flow: + +1. `start_session()` to create the agent/inference session pair. +2. `step()` for each user turn. +3. `set_reward()` when the episode completes. +4. `export_trajectory()` to retrieve interactions for training. 
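+
+A minimal sketch of this flow, assuming an already-initialized controller `ctrl`
+and a running inference-service gateway (the task id, model name, gateway address,
+and admin key below are placeholder values):
+
+```python
+import asyncio
+
+
+async def run_episode(ctrl, gateway_addr: str, admin_key: str):
+    # 1. Grant inference capacity and bind a fresh agent session to an RL session.
+    session = await ctrl.start_session(
+        "task-1",
+        inference_gateway_addr=gateway_addr,
+        inference_admin_api_key=admin_key,
+        inference_model="Qwen/Qwen3-1.7B",
+    )
+    sid = session["session_id"]
+    # 2. Drive user turns through the agent-service Gateway.
+    await ctrl.step("Hello, I need help with my account.", sid)
+    # 3. Report the final reward to the inference service.
+    await ctrl.set_reward(1.0, sid)
+    # 4. Export the recorded interactions for training.
+    return await ctrl.export_trajectory(sid, discount=1.0, style="individual")
+
+
+interactions = asyncio.run(run_episode(ctrl, "http://10.0.0.1:8000", "rollout-admin"))
+```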
+ ## Multi-turn Conversation Flow ``` @@ -189,8 +219,25 @@ areal/experimental/agent_service/ ├── app.py # create_worker_app() └── config.py # WorkerConfig dataclass -examples/agent_service/ -├── agent.py # ClaudeAgent (Claude Agent SDK) -├── run_agent_service.py # Controller-based launcher + interactive demo +examples/experimental/agent_service/ +├── __init__.py # Marks the examples package +├── claude/ +│ ├── __init__.py # Claude example package +│ ├── agent.py # ClaudeAgent (Claude Agent SDK) +│ └── run_agent_service.py # Controller-based launcher + interactive demo +├── tau2/ +│ ├── __init__.py # Tau2 example package +│ ├── agent.py # Tau2 agent-service worker example +│ ├── workflow.py # Tau2 workflow using async controller APIs +│ ├── run_rollout.py # Direct rollout driver for Tau2 +│ └── config.yaml # Tau2 example config └── README.md # Example documentation ``` + +For a standalone worker process, the agent import path now points at the nested Claude +example module: + +```bash +python -m areal.experimental.agent_service.worker \ + --agent examples.experimental.agent_service.claude.agent.ClaudeAgent +``` diff --git a/examples/experimental/agent_service/README.md b/examples/experimental/agent_service/README.md new file mode 100644 index 0000000000..967fc27fde --- /dev/null +++ b/examples/experimental/agent_service/README.md @@ -0,0 +1,91 @@ +# Agent Service Examples + +## Overview + +This directory contains experimental examples built on top of AReaL's agent service. +The examples are grouped by scenario: + +- `claude/` — a standalone Claude Agent SDK service demo +- `tau2/` — a tau2 customer-service rollout example that combines the agent service + with the experimental inference service + +The agent service exposes complete agent sessions through Router, DataProxy, Worker, +and Gateway microservices, and can be paired with the experimental inference service +for RL data collection. + +## Example 1: Claude Agent SDK Service + +This is the Claude Agent SDK example under the new `claude/` subdirectory. + +### Prerequisites + +```bash +uv pip install claude-agent-sdk +export ANTHROPIC_API_KEY=sk-... +``` + +### Run + +```bash +python examples/experimental/agent_service/claude/run_agent_service.py +python examples/experimental/agent_service/claude/run_agent_service.py --num-pairs 4 +``` + +The script creates a `LocalScheduler`, launches Guard workers, then forks Router, +Worker+DataProxy pairs, and Gateway. An interactive prompt lets you chat with the +Claude agent through `POST /v1/responses`. + +Files: + +- `claude/agent.py` — Claude Agent SDK worker implementation +- `claude/run_agent_service.py` — interactive launcher for the Claude example + +## Example 2: Tau2 Agent Service Rollout + +This example runs the tau2 customer-service agent inside the experimental agent service +while the experimental inference service collects RL trajectories. Unlike the reference +inference-service example, this script initializes `RolloutControllerV2` but does not +use `rollout_batch()`. It directly runs the tau2 workflow and returns exported +trajectories from the inference service. + +### Additional Prerequisites + +```bash +pip install git+https://github.com/dhh1995/tau2-bench.git@dhh/async-and-custom-completion +pip install pydantic-ai +export TAU2_DATA_DIR=/path/to/tau2-bench/data +``` + +If `econfig.solo_mode=false`, also start a user simulator model and set +`econfig.user_llm_base_url` in `tau2/config.yaml`. 
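+
+For reference, the user-simulator section of `tau2/config.yaml` looks like the
+following (the endpoint and model name are example values):
+
+```yaml
+econfig:
+  solo_mode: false
+  user_llm_base_url: http://localhost:8000/v1/
+  user_llm: openai/self-hosted-Qwen2.5-72B
+  user_llm_args:
+    temperature: 0.0
+    max_completion_tokens: 512
+```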
+ +### Run + +```bash +python examples/experimental/agent_service/tau2/run_rollout.py \ + --config examples/experimental/agent_service/tau2/config.yaml \ + cluster.fileroot= \ + cluster.name_resolve.nfs_record_root= +``` + +### What it does + +1. Starts the experimental inference service with `RolloutControllerV2`. +2. Starts the experimental agent service with `AgentController`. +3. For each tau2 task, the workflow: + - calls `AgentController.start_session()` (which grants capacity and starts the RL + session), + - drives the tau2 conversation through `AgentController.step()`, + - calls `AgentController.set_reward()`, + - calls `AgentController.export_trajectory()` and returns the exported interactions. + +### Files + +| File | Description | +| --- | --- | +| `claude/agent.py` | Claude Agent SDK example agent | +| `claude/run_agent_service.py` | Interactive launcher for the Claude agent service | +| `tau2/agent.py` | Tau2 agent-service worker agent | +| `tau2/workflow.py` | Tau2 rollout workflow using async controller APIs | +| `tau2/run_rollout.py` | Direct rollout driver for the tau2 workflow | +| `tau2/config.yaml` | Example config for the tau2 rollout driver | From 96c451f4747b5e9d64fbd4bd0f9889bb802ca926 Mon Sep 17 00:00:00 2001 From: nuzant Date: Sun, 26 Apr 2026 03:19:06 +0000 Subject: [PATCH 07/12] feat(examples): add tau2 agent service workflow modules Introduce the tau2 agent and workflow modules that bridge the agent service with inference-backed customer-service episodes. Key changes: - add tau2 AgentRunnable implementation - add tau2 workflow session orchestration Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- .../agent_service/tau2/__init__.py | 1 + .../experimental/agent_service/tau2/agent.py | 223 ++++++++++++++++ .../agent_service/tau2/workflow.py | 249 ++++++++++++++++++ 3 files changed, 473 insertions(+) create mode 100644 examples/experimental/agent_service/tau2/__init__.py create mode 100644 examples/experimental/agent_service/tau2/agent.py create mode 100644 examples/experimental/agent_service/tau2/workflow.py diff --git a/examples/experimental/agent_service/tau2/__init__.py b/examples/experimental/agent_service/tau2/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/examples/experimental/agent_service/tau2/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/examples/experimental/agent_service/tau2/agent.py b/examples/experimental/agent_service/tau2/agent.py new file mode 100644 index 0000000000..404cec5c16 --- /dev/null +++ b/examples/experimental/agent_service/tau2/agent.py @@ -0,0 +1,223 @@ +"""Tau2 agent for the experimental agent service.""" + +from __future__ import annotations + +import inspect +import json +import os +from typing import Any + +from areal.experimental.agent_service.types import ( + AgentRequest, + AgentResponse, + EventEmitter, +) +from areal.utils import logging + +logger = logging.getLogger("Tau2Agent") + + +def _make_pydantic_tool(tau2_tool: Any): + fn = tau2_tool._func # noqa: SLF001 + name = tau2_tool.name + schema = getattr(tau2_tool, "openai_schema", {}) or {} + doc = schema.get("function", {}).get("description", name) + + async def _wrapper(**kwargs: Any) -> str: + try: + result = fn(**kwargs) + except Exception as exc: + result = f"Tool error: {exc}" + if not isinstance(result, str): + result = json.dumps(result, default=str) + return result + + _wrapper.__name__ = name + _wrapper.__qualname__ = name + _wrapper.__doc__ = 
doc + sig = inspect.signature(fn) + _wrapper.__signature__ = inspect.Signature( + [ + inspect.Parameter( + pname, + kind=inspect.Parameter.KEYWORD_ONLY, + default=param.default, + annotation=param.annotation, + ) + for pname, param in sig.parameters.items() + ] + ) + if hasattr(fn, "__annotations__"): + _wrapper.__annotations__ = { + k: v for k, v in fn.__annotations__.items() if k != "return" + } + return _wrapper + + +def _think_tool_fn(thoughts: str) -> str: + del thoughts + return "Your thoughts are recorded. Please continue your work." + + +class Tau2Agent: + """AgentRunnable that wraps a PydanticAI agent with tau2 tools.""" + + def __init__(self, config: dict[str, Any] | None = None, **kwargs: Any) -> None: + del kwargs + try: + from pydantic_ai import Agent + from pydantic_ai.models.openai import OpenAIChatModel + from pydantic_ai.providers.openai import OpenAIProvider + from tau2.environment.tool import Tool as Tau2Tool + from tau2.registry import registry + except ImportError as exc: + raise ImportError( + "Tau2 agent service example requires 'pydantic-ai' and 'tau2-bench'" + ) from exc + + config = config or {} + tau2_cfg = config.get("tau2", {}) + agent_llm_cfg = config.get("agent_llm", {}) + + self._domain = tau2_cfg.get("domain") or os.environ.get( + "TAU2_DOMAIN", "airline" + ) + add_thinking = tau2_cfg.get("add_thinking_tool", False) + + data_dir = tau2_cfg.get("data_dir") or os.environ.get("TAU2_DATA_DIR") + if data_dir: + os.environ["TAU2_DATA_DIR"] = data_dir + + env_constructor = registry.get_env_constructor(self._domain) + env = env_constructor(solo_mode=False) + tau2_tools: list[Tau2Tool] = env.get_tools() + if add_thinking: + tau2_tools.append(Tau2Tool(_think_tool_fn)) + + tools = [_make_pydantic_tool(t) for t in tau2_tools] + system_prompt = env.get_policy() + + model_name = agent_llm_cfg.get("model", "openai:default") + base_url = agent_llm_cfg.get("base_url") + api_key = agent_llm_cfg.get("api_key", "unused") + + if base_url: + model: Any = OpenAIChatModel( + model_name.replace("openai:", ""), + provider=OpenAIProvider(base_url=base_url, api_key=api_key), + ) + else: + model = model_name + + self._openai_chat_model = OpenAIChatModel + self._openai_provider = OpenAIProvider + self._agent = Agent(model, system_prompt=system_prompt, tools=tools) + logger.info( + "Tau2Agent initialized (domain=%s, tools=%d, model=%s)", + self._domain, + len(tools), + model_name, + ) + + def _resolve_model(self, metadata: dict[str, Any]) -> Any: + base_url = metadata.get("inference_base_url") + if not base_url: + return self._agent.model + model_name = metadata.get("inference_model", "default") + api_key = metadata.get("inference_api_key", "unused") + return self._openai_chat_model( + model_name, + provider=self._openai_provider(base_url=base_url, api_key=api_key), + ) + + async def run( + self, + request: AgentRequest, + *, + emitter: EventEmitter, + ) -> AgentResponse: + from pydantic_ai.messages import ( + ModelRequest, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + ) + from pydantic_ai.messages import ModelResponse as PAModelResponse + + message_history: list[ModelRequest | PAModelResponse] = [] + for msg in request.history: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "user": + message_history.append( + ModelRequest(parts=[UserPromptPart(content=content or "")]) + ) + elif role == "assistant": + tool_calls = msg.get("tool_calls") + if tool_calls: + parts = [] + for tc in tool_calls: + fn = tc.get("function", tc) + parts.append( 
+ ToolCallPart( + tool_name=fn.get("name", ""), + args=fn.get("arguments", ""), + tool_call_id=tc.get("id", ""), + ) + ) + message_history.append(PAModelResponse(parts=parts)) + elif content: + message_history.append( + PAModelResponse(parts=[TextPart(content=content)]) + ) + elif role == "tool": + tool_call_id = msg.get("tool_call_id", "") + message_history.append( + ModelRequest( + parts=[ + ToolReturnPart( + tool_name=tool_call_id, + content=content or "", + tool_call_id=tool_call_id, + ) + ] + ) + ) + + try: + result = await self._agent.run( + request.message, + message_history=message_history, + model=self._resolve_model(request.metadata), + ) + except Exception as exc: + logger.error("Tau2Agent turn failed: %s", exc) + await emitter.emit_delta(f"Agent error: {exc}") + return AgentResponse( + summary=f"Agent error: {exc}", metadata={"tool_calls": []} + ) + + final_text = str(result.output) if result.output else "" + tool_calls: list[dict[str, Any]] = [] + for msg in result.new_messages(): + if not hasattr(msg, "parts"): + continue + for part in msg.parts: + kind = getattr(part, "part_kind", "") + if kind == "tool-call": + name = getattr(part, "tool_name", "") + args = getattr(part, "args", "") + if isinstance(args, dict): + args = json.dumps(args) + await emitter.emit_tool_call(name=name, args=str(args)) + tool_calls.append({"name": name, "arguments": args}) + elif kind == "tool-return": + name = getattr(part, "tool_name", "") + content = str(getattr(part, "content", "")) + await emitter.emit_tool_result(name=name, result=content) + + if final_text: + await emitter.emit_delta(final_text) + + return AgentResponse(summary=final_text[:200], metadata={"tool_calls": tool_calls}) diff --git a/examples/experimental/agent_service/tau2/workflow.py b/examples/experimental/agent_service/tau2/workflow.py new file mode 100644 index 0000000000..4549633cf3 --- /dev/null +++ b/examples/experimental/agent_service/tau2/workflow.py @@ -0,0 +1,249 @@ +"""Tau2 workflow using the experimental agent service.""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any + +from openai import AsyncOpenAI + +from areal.api.workflow_api import RolloutWorkflow +from areal.infra import workflow_context +from areal.utils import logging, stats_tracker + +if TYPE_CHECKING: + from areal.api.engine_api import InferenceEngine + from areal.experimental.agent_service.controller.controller import AgentController + from areal.experimental.openai.types import InteractionWithTokenLogpReward + +logger = logging.getLogger("Tau2AgentServiceWorkflow") + + +def _extract_response_text(response: dict[str, Any]) -> str: + parts: list[str] = [] + for item in response.get("output", []): + if item.get("type") == "message": + for block in item.get("content", []): + if block.get("type") == "output_text": + parts.append(block.get("text", "")) + return "\n".join(parts).strip() + + +def _extract_completion_text(completion: Any) -> str: + choice = completion.choices[0] + message = getattr(choice, "message", None) + content = getattr(message, "content", "") if message is not None else "" + if isinstance(content, str): + return content.strip() + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text" and item.get("text"): + parts.append(str(item["text"])) + else: + text = getattr(item, "text", None) + if text: + parts.append(str(text)) + return "\n".join(parts).strip() + return str(content).strip() + + +class 
Tau2AgentServiceWorkflow(RolloutWorkflow): + def __init__( + self, + agent_controller: AgentController, + inference_gateway_addr: str, + inference_admin_api_key: str, + inference_model: str = "", + econfig: dict[str, Any] | None = None, + gen_args: dict[str, Any] | None = None, + timeout: float = 600.0, + max_turns: int = 10, + discount: float = 1.0, + export_style: str = "individual", + ) -> None: + from examples.tau2.utils import Tau2EnvConfig + + self.agent_controller = agent_controller + self.inference_gateway_addr = inference_gateway_addr.rstrip("/") + self.inference_admin_api_key = inference_admin_api_key + self.inference_model = inference_model + self.econfig = ( + Tau2EnvConfig(**econfig) + if isinstance(econfig, dict) + else (econfig or Tau2EnvConfig()) + ) + self.gen_args = gen_args or {} + self.timeout = timeout + self.max_turns = max_turns + self.discount = discount + self.export_style = export_style + + async def _run_dialog( + self, + data: dict[str, Any], + agent_session_id: str, + ) -> float: + from tau2.data_model.message import AssistantMessage, UserMessage + from tau2.data_model.simulation import SimulationRun, TerminationReason + from tau2.evaluator.evaluator import EvaluationType, evaluate_simulation + + from examples.tau2.agent import _get_task + from examples.tau2.utils import Tau2EnvConfig + + econfig = self.econfig + if "econfig" in data: + econfig = Tau2EnvConfig(**data["econfig"]) + + task = _get_task( + domain=econfig.domain, + task_id=data["task_id"], + split=data.get("split", "train"), + ) + first_user_message = str(data.get("prompt") or task.user_scenario).strip() + if not first_user_message: + raise ValueError("data.prompt or task.user_scenario is required") + + user_client = None + if not econfig.solo_mode: + if not econfig.user_llm_base_url: + raise ValueError( + "econfig.user_llm_base_url is required when solo_mode is false" + ) + user_client = AsyncOpenAI( + base_url=econfig.user_llm_base_url, + api_key="dummy", + max_retries=3, + timeout=120.0, + ) + + tau2_messages: list[UserMessage | AssistantMessage] = [] + chat_history: list[dict[str, str]] = [ + {"role": "user", "content": first_user_message} + ] + next_user_message = first_user_message + + for turn_idx in range(self.max_turns): + response = await self.agent_controller.step(next_user_message, agent_session_id) + agent_text = _extract_response_text(response) or "(no response)" + + tau2_messages.append( + UserMessage( + role="user", + content=next_user_message, + turn_idx=len(tau2_messages), + ) + ) + tau2_messages.append( + AssistantMessage( + role="assistant", + content=agent_text, + turn_idx=len(tau2_messages), + ) + ) + + if turn_idx + 1 >= self.max_turns or user_client is None: + break + + chat_history.append({"role": "assistant", "content": agent_text}) + completion = await user_client.chat.completions.create( + model=econfig.user_llm or "dummy", + messages=[ + { + "role": "system", + "content": ( + "You are simulating the tau2 user described below. 
" + "Respond with the user's next message only, in one turn, " + "based on the conversation so far.\n\n" + f"User scenario:\n{task.user_scenario}" + ), + }, + *chat_history, + ], + **(econfig.user_llm_args or {}), + ) + next_user_message = _extract_completion_text(completion) + if not next_user_message: + break + chat_history.append({"role": "user", "content": next_user_message}) + + simulation = SimulationRun( + id=f"agent-svc-{task.id}", + task_id=task.id, + messages=tau2_messages, + start_time="", + end_time="", + duration=0.0, + termination_reason=TerminationReason.USER_STOP, + ) + reward_info = evaluate_simulation( + simulation=simulation, + task=task, + evaluation_type=EvaluationType.ALL, + solo_mode=econfig.solo_mode, + domain=econfig.domain, + ) + return float(reward_info.reward) + + async def arun_episode( + self, + engine: InferenceEngine, + data: dict[str, Any], + ) -> dict[str, InteractionWithTokenLogpReward] | None: + del engine + task_id = str(data.get("task_id") or workflow_context.get().task_id) + session = await self.agent_controller.start_session( + task_id=task_id, + inference_gateway_addr=self.inference_gateway_addr, + inference_admin_api_key=self.inference_admin_api_key, + inference_model=self.inference_model, + ) + + trajectory_id: int | None = None + finished = False + try: + reward = await asyncio.wait_for( + self._run_dialog(data, session["session_id"]), + timeout=self.timeout, + ) + reward_result = await self.agent_controller.set_reward( + reward, + session["session_id"], + ) + raw_trajectory_id = reward_result.get("trajectory_id") + trajectory_id = ( + int(raw_trajectory_id) if raw_trajectory_id is not None else None + ) + finished = True + except Exception: + logger.warning( + "Tau2 agent-service task failed. This trajectory will be rejected." + ) + if not finished: + try: + await self.agent_controller.set_reward(0.0, session["session_id"]) + except Exception: + logger.warning( + "Failed to finish session %s after workflow failure", + session["session_id"], + ) + raise + + interactions = await self.agent_controller.export_trajectory( + session["session_id"], + trajectory_id=trajectory_id, + discount=self.discount, + style=self.export_style, + ) + if not interactions: + logger.warning( + "Session %s has no interactions, trajectory will be rejected.", + session["session_id"], + ) + return None + + last_id = next(reversed(interactions)) + last_reward = interactions[last_id].reward + stats_tracker.get(workflow_context.stat_scope()).scalar(reward=last_reward) + return interactions From a3f12e49ea32a01c5c5832fa26675c01a6d6e7a7 Mon Sep 17 00:00:00 2001 From: nuzant Date: Sun, 26 Apr 2026 03:19:13 +0000 Subject: [PATCH 08/12] feat(examples): add tau2 agent service rollout driver Add a runnable tau2 rollout entrypoint and config for validating the experimental agent service against inference-backed trajectories. 
Key changes: - add tau2 rollout config - add direct rollout driver - support async batch gather validation control Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- .../agent_service/tau2/config.yaml | 134 ++++++++++++ .../agent_service/tau2/run_rollout.py | 206 ++++++++++++++++++ 2 files changed, 340 insertions(+) create mode 100644 examples/experimental/agent_service/tau2/config.yaml create mode 100644 examples/experimental/agent_service/tau2/run_rollout.py diff --git a/examples/experimental/agent_service/tau2/config.yaml b/examples/experimental/agent_service/tau2/config.yaml new file mode 100644 index 0000000000..a3db01562d --- /dev/null +++ b/examples/experimental/agent_service/tau2/config.yaml @@ -0,0 +1,134 @@ +experiment_name: tau2-agent-service-rollout +trial_name: 1.7b-telecom + +seed: 1 +enable_offload: false +total_train_epochs: 1 +total_train_steps: null +tokenizer_path: ${model_path} + +model_path: Qwen/Qwen3-1.7B + +cluster: + n_nodes: 1 + n_gpus_per_node: 2 + fileroot: /path/to/experiments + name_resolve: + type: nfs + nfs_record_root: /path/to/name_resolve + +scheduler: + type: local + +gconfig: + n_samples: 1 + min_new_tokens: 0 + max_new_tokens: 8192 + max_tokens: 16384 + greedy: false + temperature: 1.0 + +rollout: + _version: v2 + backend: "sglang:d2" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 32 + queue_size: null + consumer_batch_size: 8 + max_head_offpolicyness: 1000000000 + enable_rollout_tracing: true + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: false + admin_api_key: rollout-admin + openai: + mode: inline + tool_call_parser: qwen25 + reasoning_parser: qwen3 + engine_max_tokens: ${gconfig.max_tokens} + export_style: individual + turn_discount: 1.0 + admin_api_key: rollout-admin + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + cpu: 2 + mem: 16 + cmd: python3 -m areal.experimental.inference_service.guard + env_vars: + AREAL_PROXY_WARN_ONCE: "1" + +agent_service: + agent_cls_path: examples.experimental.agent_service.tau2.agent.Tau2Agent + admin_api_key: areal-agent-admin + num_pairs: 1 + setup_timeout: 120.0 + health_poll_interval: 5.0 + drain_timeout: 30.0 + log_level: info + env: {} + +econfig: + domain: telecom + max_steps: 50 + add_thinking_tool: false + solo_mode: false + user_llm_base_url: http://localhost:8000/v1/ + user_llm: openai/self-hosted-Qwen2.5-72B + user_llm_args: + temperature: 0.0 + max_completion_tokens: 512 + turn_discount: 1.0 + invalid_format_penalty: 0.1 + +sglang: + model_path: ${model_path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: bfloat16 + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +train_dataset: + batch_size: 8 + shuffle: true + pin_memory: true + num_workers: 4 + path: tau2/train + type: rl + max_length: 1024 + +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: 
${cluster.fileroot} + wandb: + mode: disabled diff --git a/examples/experimental/agent_service/tau2/run_rollout.py b/examples/experimental/agent_service/tau2/run_rollout.py new file mode 100644 index 0000000000..247637fe21 --- /dev/null +++ b/examples/experimental/agent_service/tau2/run_rollout.py @@ -0,0 +1,206 @@ +"""Direct rollout driver for the tau2 agent-service workflow.""" + +from __future__ import annotations + +import asyncio +import os +import sys +import warnings +from copy import deepcopy +from dataclasses import asdict, dataclass, field +from typing import Any + +from datasets import Dataset + +from areal.api.alloc_mode import ModelAllocation +from areal.api.cli_args import ( + AgentConfig, + BaseExperimentConfig, + GenerationHyperparameters, + InferenceEngineConfig, + SGLangConfig, + TrainDatasetConfig, + load_expr_config, +) +from areal.experimental.agent_service.controller import AgentController +from areal.experimental.inference_service.controller.controller import ( + RolloutControllerV2, +) +from areal.utils import logging + +logger = logging.getLogger("Tau2AgentServiceRollout") + + +def get_tau2_dataset(domain: str, type: str = "rl", split: str = "train") -> Dataset: + from tau2.registry import registry + + assert type == "rl", "Only RL dataset is supported for now" + splits_loader_fn = registry.get_task_splits_loader(domain) + if splits_loader_fn is None: + raise ValueError(f"No task splits loader found for domain {domain}") + splits = splits_loader_fn() + if split not in splits: + raise ValueError( + f"Split {split} not found for domain {domain}, available splits: {list(splits.keys())}" + ) + task_ids = splits[split] + dataset_items = [{"task_id": task_id, "split": split} for task_id in task_ids] + if len(dataset_items) < 128: + original_items = dataset_items.copy() + while len(dataset_items) < 128: + dataset_items.extend(original_items) + return Dataset.from_list(dataset_items) + + +@dataclass +class Tau2AgentServiceRolloutConfig(BaseExperimentConfig): + gconfig: GenerationHyperparameters = field(default_factory=GenerationHyperparameters) + rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig) + model_path: str = "" + econfig: dict[str, Any] = field(default_factory=dict) + agent_service: AgentConfig = field( + default_factory=lambda: AgentConfig( + agent_cls_path="examples.experimental.agent_service.tau2.agent.Tau2Agent" + ) + ) + sglang: SGLangConfig = field(default_factory=SGLangConfig) + train_dataset: TrainDatasetConfig = field(default_factory=TrainDatasetConfig) + + +async def _run_rollouts( + workflow: Any, + controller: RolloutControllerV2, + dataloader: Any, + *, + max_batches: int | None = None, +) -> None: + batch_count = 0 + for batch_idx, batch in enumerate(dataloader): + if max_batches is not None and batch_count >= max_batches: + break + + keys = list(batch.keys()) + batch_size = len(batch[keys[0]]) + data_rows = [{k: batch[k][i] for k in keys} for i in range(batch_size)] + + results = await asyncio.gather( + *(workflow.arun_episode(controller, row) for row in data_rows) + ) + + rewards: list[float] = [] + trajectories = 0 + for result in results: + if not result: + continue + trajectories += 1 + last_id = next(reversed(result)) + reward = result[last_id].reward + if reward is not None: + rewards.append(float(reward)) + + avg_reward = sum(rewards) / len(rewards) if rewards else 0.0 + logger.info( + "Batch %d: n_trajs=%d, rewards=%s, avg_reward=%.4f", + batch_idx, + trajectories, + rewards, + avg_reward, + ) + batch_count += 1 + + 
logger.info("Rollout complete (%d batches)", batch_count) + + +def main(argv: list[str]) -> None: + warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") + + config, _ = load_expr_config(argv, Tau2AgentServiceRolloutConfig) + rollout_cfg = deepcopy(config.rollout) + rollout_cfg.model = config.model_path + + from examples.experimental.agent_service.tau2.workflow import ( + Tau2AgentServiceWorkflow, + ) + from examples.tau2.utils import Tau2EnvConfig + + econfig = ( + Tau2EnvConfig(**config.econfig) + if isinstance(config.econfig, dict) + else config.econfig + ) + train_dataset = get_tau2_dataset( + domain=econfig.domain, + type=config.train_dataset.type, + split=config.train_dataset.path.split("/")[-1], + ) + + from torch.utils.data import DataLoader + + dataloader = DataLoader( + train_dataset, + batch_size=config.train_dataset.batch_size, + shuffle=config.train_dataset.shuffle, + num_workers=0, + ) + + from areal.infra.scheduler.local import LocalScheduler + from areal.infra.scheduler.slurm import SlurmScheduler + + if config.scheduler.type == "local": + scheduler = LocalScheduler(exp_config=config) + elif config.scheduler.type == "slurm": + scheduler = SlurmScheduler(exp_config=config) + else: + raise NotImplementedError(f"Unknown scheduler type: {config.scheduler.type}") + + rollout_alloc = ModelAllocation.from_str(config.rollout.backend, name="rollout") + if rollout_alloc.backend == "sglang": + server_args = asdict(config.sglang) + elif rollout_alloc.backend == "vllm": + server_args = asdict(config.vllm) + else: + raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") + + rollout_controller = RolloutControllerV2(config=rollout_cfg, scheduler=scheduler) + rollout_controller.initialize(role="rollout", server_args=server_args) + + agent_controller = AgentController(config=config.agent_service, scheduler=scheduler) + agent_controller.initialize() + + workflow = Tau2AgentServiceWorkflow( + agent_controller=agent_controller, + inference_gateway_addr=rollout_controller.proxy_gateway_addr, + inference_admin_api_key=rollout_cfg.admin_api_key, + inference_model=config.model_path, + econfig=asdict(econfig), + gen_args={ + "temperature": config.gconfig.temperature, + "max_completion_tokens": config.gconfig.max_new_tokens, + }, + timeout=600.0, + discount=rollout_cfg.openai.turn_discount if rollout_cfg.openai else 1.0, + export_style=( + rollout_cfg.openai.export_style if rollout_cfg.openai else "individual" + ), + ) + + max_batches_env = os.environ.get("AREAL_MAX_BATCHES") + max_batches = int(max_batches_env) if max_batches_env is not None else None + + try: + asyncio.run( + _run_rollouts( + workflow, + rollout_controller, + dataloader, + max_batches=max_batches, + ) + ) + finally: + agent_controller.destroy() + rollout_controller.destroy() + scheduler.delete_workers(None) + + +if __name__ == "__main__": + main(sys.argv[1:]) From 5578b7022a9ecec1f3092b5ba67bf1ae5ef1c675 Mon Sep 17 00:00:00 2001 From: nuzant Date: Sun, 26 Apr 2026 03:19:20 +0000 Subject: [PATCH 09/12] test(examples): cover tau2 agent service workflow Add focused workflow coverage for exported tau2 trajectories and reward reporting in the experimental agent-service example. 
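
The new test is async (pytest.mark.asyncio) and can be run in isolation with:

    pytest tests/experimental/agent_service/test_tau2_workflow.py
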
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- .../agent_service/test_tau2_workflow.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 tests/experimental/agent_service/test_tau2_workflow.py diff --git a/tests/experimental/agent_service/test_tau2_workflow.py b/tests/experimental/agent_service/test_tau2_workflow.py new file mode 100644 index 0000000000..3fa8cb94b5 --- /dev/null +++ b/tests/experimental/agent_service/test_tau2_workflow.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from examples.experimental.agent_service.tau2.workflow import Tau2AgentServiceWorkflow + + +@pytest.mark.asyncio +@patch("examples.experimental.agent_service.tau2.workflow.workflow_context") +@patch("examples.experimental.agent_service.tau2.workflow.stats_tracker") +async def test_arun_episode_returns_exported_interactions( + mock_stats_tracker, + mock_workflow_context, +): + controller = MagicMock() + controller.start_session = AsyncMock( + return_value={ + "session_id": "agent-sess-1", + "inference_session_id": "inf-sess-1", + "api_key": "sess-key", + } + ) + controller.set_reward = AsyncMock(return_value={"trajectory_id": 3}) + exported = {"last": SimpleNamespace(reward=1.0)} + controller.export_trajectory = AsyncMock(return_value=exported) + + workflow = Tau2AgentServiceWorkflow( + agent_controller=controller, + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_model="Qwen/Test", + econfig={"domain": "telecom", "solo_mode": True}, + timeout=10.0, + ) + workflow._run_dialog = AsyncMock(return_value=1.0) + mock_scope = object() + mock_workflow_context.get.return_value = SimpleNamespace(task_id="task-from-context") + mock_workflow_context.stat_scope.return_value = mock_scope + mock_stats = MagicMock() + mock_stats_tracker.get.return_value = mock_stats + + result = await workflow.arun_episode(engine=object(), data={"task_id": "task-1"}) + + assert result is exported + controller.start_session.assert_awaited_once_with( + task_id="task-1", + inference_gateway_addr="http://inference", + inference_admin_api_key="rollout-admin", + inference_model="Qwen/Test", + ) + workflow._run_dialog.assert_awaited_once_with({"task_id": "task-1"}, "agent-sess-1") + controller.set_reward.assert_awaited_once_with(1.0, "agent-sess-1") + controller.export_trajectory.assert_awaited_once_with( + "agent-sess-1", + trajectory_id=3, + discount=1.0, + style="individual", + ) + mock_stats.scalar.assert_called_once_with(reward=1.0) From c550f9518da6fef081e7ce0b426783bd556460da Mon Sep 17 00:00:00 2001 From: nuzant Date: Sun, 26 Apr 2026 03:29:05 +0000 Subject: [PATCH 10/12] style(service): format agent service docs and controller Normalize wrapped lines in the agent service README and controller so pre-commit leaves the service module clean. 
Co-authored-by: Sisyphus --- areal/experimental/agent_service/README.md | 29 +++++++++---------- .../agent_service/controller/controller.py | 8 ++--- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/areal/experimental/agent_service/README.md b/areal/experimental/agent_service/README.md index 5543de01a9..c496f61205 100644 --- a/areal/experimental/agent_service/README.md +++ b/areal/experimental/agent_service/README.md @@ -6,8 +6,7 @@ The Agent Service provides **agent-level** capabilities on top of AReaL's model- proxy. It exposes complete agent sessions — multi-turn conversations with tool use, memory, and pluggable agent frameworks — via independent HTTP microservices. It also includes an `AgentController` that can launch the stack through Guard processes and -bridge agent conversations to the experimental inference service for RL data -collection. +bridge agent conversations to the experimental inference service for RL data collection. ## Architecture @@ -141,23 +140,23 @@ class EventEmitter(Protocol): `AgentController` is the integration point used by the examples and rollout workflows. It manages the agent-service stack and exposes async helpers for RL/inference flows: -| Method | Description | -| ------ | ----------- | -| `initialize()` | Launch Guards, Router, Worker+DataProxy pairs, Gateway, and the health monitor | -| `destroy()` | Tear down the full stack in reverse order | -| `scale_up(count)` | Add Worker+DataProxy pairs | -| `scale_down(count)` | Unregister, drain, and remove pairs | -| `start_session(...)` | Grant inference capacity and create an RL session bound to an agent session | -| `step(input, session_id, metadata=None)` | Send a turn through the agent-service Gateway `POST /v1/responses` | -| `set_reward(reward, session_id, interaction_id=None)` | Forward the final reward to the inference service | -| `export_trajectory(session_id, ...)` | Export serialized interactions from the inference service | +| Method | Description | +| ----------------------------------------------------- | ------------------------------------------------------------------------------ | +| `initialize()` | Launch Guards, Router, Worker+DataProxy pairs, Gateway, and the health monitor | +| `destroy()` | Tear down the full stack in reverse order | +| `scale_up(count)` | Add Worker+DataProxy pairs | +| `scale_down(count)` | Unregister, drain, and remove pairs | +| `start_session(...)` | Grant inference capacity and create an RL session bound to an agent session | +| `step(input, session_id, metadata=None)` | Send a turn through the agent-service Gateway `POST /v1/responses` | +| `set_reward(reward, session_id, interaction_id=None)` | Forward the final reward to the inference service | +| `export_trajectory(session_id, ...)` | Export serialized interactions from the inference service | Typical rollout flow: 1. `start_session()` to create the agent/inference session pair. -2. `step()` for each user turn. -3. `set_reward()` when the episode completes. -4. `export_trajectory()` to retrieve interactions for training. +1. `step()` for each user turn. +1. `set_reward()` when the episode completes. +1. `export_trajectory()` to retrieve interactions for training. 
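+
+A minimal sketch of this loop, assuming an already-initialized `AgentController`
+plus the inference-service gateway address, admin key, and model name obtained
+from the rollout side. The literal values below are placeholders, and the
+sketch mirrors how the tau2 example threads `set_reward()`'s returned
+`trajectory_id` into `export_trajectory()`:
+
+```python
+async def rollout_once(agent_controller):
+    # 1. Grant inference capacity and bind an agent session to an RL session.
+    session = await agent_controller.start_session(
+        task_id="example-task",
+        inference_gateway_addr="http://inference",
+        inference_admin_api_key="rollout-admin",
+        inference_model="Qwen/Test",
+    )
+    session_id = session["session_id"]
+
+    # 2. Drive a user turn through the Gateway's `POST /v1/responses`; in a
+    #    real episode the response would seed the next turn's input.
+    response = await agent_controller.step("Hello!", session_id)
+
+    # 3. Forward the episode's final reward to the inference service.
+    ack = await agent_controller.set_reward(1.0, session_id)
+
+    # 4. Export serialized interactions for training.
+    return await agent_controller.export_trajectory(
+        session_id,
+        trajectory_id=ack["trajectory_id"],
+        discount=1.0,
+        style="individual",
+    )
+```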
## Multi-turn Conversation Flow diff --git a/areal/experimental/agent_service/controller/controller.py b/areal/experimental/agent_service/controller/controller.py index 633034c38b..19f662e331 100644 --- a/areal/experimental/agent_service/controller/controller.py +++ b/areal/experimental/agent_service/controller/controller.py @@ -475,13 +475,13 @@ async def step( metadata: dict[str, Any] | None = None, ) -> dict[str, Any]: if not self._gateway_addr: - raise RuntimeError("step() requires the agent-service gateway to be running") + raise RuntimeError( + "step() requires the agent-service gateway to be running" + ) session = self._resolve_session(session_id) input_items = ( - [{"type": "message", "content": input}] - if isinstance(input, str) - else input + [{"type": "message", "content": input}] if isinstance(input, str) else input ) merged_metadata: dict[str, Any] = { From f64719f518a635021ce4d10415d6c61b3fd001b8 Mon Sep 17 00:00:00 2001 From: nuzant Date: Sun, 26 Apr 2026 03:29:05 +0000 Subject: [PATCH 11/12] style(examples): format experimental agent service examples Apply mdformat and ruff formatting to the experimental agent service examples so the example package stays pre-commit clean. Co-authored-by: Sisyphus --- examples/experimental/agent_service/README.md | 36 +++++++++---------- .../experimental/agent_service/tau2/agent.py | 4 ++- .../agent_service/tau2/run_rollout.py | 4 ++- .../agent_service/tau2/workflow.py | 4 ++- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/examples/experimental/agent_service/README.md b/examples/experimental/agent_service/README.md index 967fc27fde..74787b53f3 100644 --- a/examples/experimental/agent_service/README.md +++ b/examples/experimental/agent_service/README.md @@ -2,16 +2,16 @@ ## Overview -This directory contains experimental examples built on top of AReaL's agent service. -The examples are grouped by scenario: +This directory contains experimental examples built on top of AReaL's agent service. The +examples are grouped by scenario: - `claude/` — a standalone Claude Agent SDK service demo -- `tau2/` — a tau2 customer-service rollout example that combines the agent service - with the experimental inference service +- `tau2/` — a tau2 customer-service rollout example that combines the agent service with + the experimental inference service -The agent service exposes complete agent sessions through Router, DataProxy, Worker, -and Gateway microservices, and can be paired with the experimental inference service -for RL data collection. +The agent service exposes complete agent sessions through Router, DataProxy, Worker, and +Gateway microservices, and can be paired with the experimental inference service for RL +data collection. ## Example 1: Claude Agent SDK Service @@ -32,8 +32,8 @@ python examples/experimental/agent_service/claude/run_agent_service.py --num-pai ``` The script creates a `LocalScheduler`, launches Guard workers, then forks Router, -Worker+DataProxy pairs, and Gateway. An interactive prompt lets you chat with the -Claude agent through `POST /v1/responses`. +Worker+DataProxy pairs, and Gateway. An interactive prompt lets you chat with the Claude +agent through `POST /v1/responses`. Files: @@ -71,8 +71,8 @@ python examples/experimental/agent_service/tau2/run_rollout.py \ ### What it does 1. Starts the experimental inference service with `RolloutControllerV2`. -2. Starts the experimental agent service with `AgentController`. -3. For each tau2 task, the workflow: +1. 
Starts the experimental agent service with `AgentController`. +1. For each tau2 task, the workflow: - calls `AgentController.start_session()` (which grants capacity and starts the RL session), - drives the tau2 conversation through `AgentController.step()`, @@ -81,11 +81,11 @@ python examples/experimental/agent_service/tau2/run_rollout.py \ ### Files -| File | Description | -| --- | --- | -| `claude/agent.py` | Claude Agent SDK example agent | +| File | Description | +| ----------------------------- | ------------------------------------------------- | +| `claude/agent.py` | Claude Agent SDK example agent | | `claude/run_agent_service.py` | Interactive launcher for the Claude agent service | -| `tau2/agent.py` | Tau2 agent-service worker agent | -| `tau2/workflow.py` | Tau2 rollout workflow using async controller APIs | -| `tau2/run_rollout.py` | Direct rollout driver for the tau2 workflow | -| `tau2/config.yaml` | Example config for the tau2 rollout driver | +| `tau2/agent.py` | Tau2 agent-service worker agent | +| `tau2/workflow.py` | Tau2 rollout workflow using async controller APIs | +| `tau2/run_rollout.py` | Direct rollout driver for the tau2 workflow | +| `tau2/config.yaml` | Example config for the tau2 rollout driver | diff --git a/examples/experimental/agent_service/tau2/agent.py b/examples/experimental/agent_service/tau2/agent.py index 404cec5c16..f5c37ab440 100644 --- a/examples/experimental/agent_service/tau2/agent.py +++ b/examples/experimental/agent_service/tau2/agent.py @@ -220,4 +220,6 @@ async def run( if final_text: await emitter.emit_delta(final_text) - return AgentResponse(summary=final_text[:200], metadata={"tool_calls": tool_calls}) + return AgentResponse( + summary=final_text[:200], metadata={"tool_calls": tool_calls} + ) diff --git a/examples/experimental/agent_service/tau2/run_rollout.py b/examples/experimental/agent_service/tau2/run_rollout.py index 247637fe21..919569438b 100644 --- a/examples/experimental/agent_service/tau2/run_rollout.py +++ b/examples/experimental/agent_service/tau2/run_rollout.py @@ -54,7 +54,9 @@ def get_tau2_dataset(domain: str, type: str = "rl", split: str = "train") -> Dat @dataclass class Tau2AgentServiceRolloutConfig(BaseExperimentConfig): - gconfig: GenerationHyperparameters = field(default_factory=GenerationHyperparameters) + gconfig: GenerationHyperparameters = field( + default_factory=GenerationHyperparameters + ) rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig) model_path: str = "" econfig: dict[str, Any] = field(default_factory=dict) diff --git a/examples/experimental/agent_service/tau2/workflow.py b/examples/experimental/agent_service/tau2/workflow.py index 4549633cf3..990babb9ea 100644 --- a/examples/experimental/agent_service/tau2/workflow.py +++ b/examples/experimental/agent_service/tau2/workflow.py @@ -125,7 +125,9 @@ async def _run_dialog( next_user_message = first_user_message for turn_idx in range(self.max_turns): - response = await self.agent_controller.step(next_user_message, agent_session_id) + response = await self.agent_controller.step( + next_user_message, agent_session_id + ) agent_text = _extract_response_text(response) or "(no response)" tau2_messages.append( From 858c2aee60e2f0bc1099f031344165448c6cc11c Mon Sep 17 00:00:00 2001 From: nuzant Date: Sun, 26 Apr 2026 03:29:05 +0000 Subject: [PATCH 12/12] style: format tau2 workflow test Reflow the tau2 workflow test to match Ruff formatting after the repository-wide pre-commit run. 
Co-authored-by: Sisyphus --- tests/experimental/agent_service/test_tau2_workflow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/experimental/agent_service/test_tau2_workflow.py b/tests/experimental/agent_service/test_tau2_workflow.py index 3fa8cb94b5..a3daad4d6f 100644 --- a/tests/experimental/agent_service/test_tau2_workflow.py +++ b/tests/experimental/agent_service/test_tau2_workflow.py @@ -39,7 +39,9 @@ async def test_arun_episode_returns_exported_interactions( ) workflow._run_dialog = AsyncMock(return_value=1.0) mock_scope = object() - mock_workflow_context.get.return_value = SimpleNamespace(task_id="task-from-context") + mock_workflow_context.get.return_value = SimpleNamespace( + task_id="task-from-context" + ) mock_workflow_context.stat_scope.return_value = mock_scope mock_stats = MagicMock() mock_stats_tracker.get.return_value = mock_stats
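
For reference, the reformatted workflow test can be exercised on its own; a
minimal invocation, assuming `pytest-asyncio` is installed since the test
relies on the `@pytest.mark.asyncio` marker:

```bash
python -m pytest tests/experimental/agent_service/test_tau2_workflow.py -v
```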