diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index edad30414d..08dc9f7dd8 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1963,6 +1963,68 @@ def __post_init__(self): raise ValueError("admin_api_key must not be empty or whitespace-only") +@dataclass +class AgentConfig: + """Configuration for the experimental agent service controller.""" + + agent_cls_path: str = field( + default="", + metadata={ + "help": "Fully-qualified import path for the AgentRunnable implementation." + }, + ) + admin_api_key: str = field( + default="areal-agent-admin", + metadata={"help": "Shared admin API key for agent-service inter-service auth."}, + ) + num_pairs: int = field( + default=1, + metadata={"help": "Number of Worker+DataProxy pairs to launch on initialize."}, + ) + setup_timeout: float = field( + default=120.0, + metadata={ + "help": "Timeout in seconds waiting for each service to become healthy." + }, + ) + health_poll_interval: float = field( + default=5.0, + metadata={ + "help": "Seconds between pair health polls; 0 disables health monitoring." + }, + ) + drain_timeout: float = field( + default=30.0, + metadata={ + "help": "Seconds to wait for active sessions to drain before force-killing a pair." + }, + ) + log_level: str = field( + default="info", + metadata={"help": "Log level for spawned agent-service micro-services."}, + ) + env: dict[str, str] = field( + default_factory=dict, + metadata={ + "help": "Extra environment variables passed to all forked child processes." + }, + ) + + def __post_init__(self) -> None: + if not self.agent_cls_path: + raise ValueError("agent_cls_path must be a non-empty import path") + if self.num_pairs < 0: + raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}") + if self.setup_timeout <= 0: + raise ValueError( + f"setup_timeout must be positive, got {self.setup_timeout}" + ) + if self.drain_timeout < 0: + raise ValueError( + f"drain_timeout must be non-negative, got {self.drain_timeout}" + ) + + @dataclass class InferenceEngineConfig: """Configuration for inference servers, including offpolicyness control.""" @@ -2081,6 +2143,79 @@ class InferenceEngineConfig: }, ) + # v2 controller options + _version: str = field( + default="v1", + metadata={ + "help": "Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2.", + "choices": ["v1", "v2"], + }, + ) + model: str = field( + default="default", + metadata={"help": "Model name exposed through the inference-service gateway."}, + ) + routing_strategy: str = field( + default="round_robin", + metadata={"help": "Routing strategy for the inference-service router."}, + ) + poll_interval: float = field( + default=5.0, + metadata={ + "help": "Health-poll interval in seconds for the inference-service router." + }, + ) + set_reward_finish_timeout: float = field( + default=0.0, + metadata={ + "help": "Timeout in seconds to wait for additional reward updates before finalizing a session." + }, + ) + session_timeout_seconds: float = field( + default=3600.0, + metadata={ + "help": "Timeout in seconds before an inactive inference-service session is considered stale and cleaned up." + }, + ) + stale_session_cleanup_interval_seconds: float = field( + default=60.0, + metadata={ + "help": "Polling interval in seconds for stale-session cleanup in inference-service data proxies." + }, + ) + stale_session_dump_path: str = field( + default="", + metadata={ + "help": "Optional directory path where stale-session trajectory dumps are written before cleanup." + }, + ) + log_level: str = field( + default="info", + metadata={"help": "Log level for inference-service micro-services."}, + ) + admin_api_key: str = field( + default="areal-admin-key", + metadata={ + "help": "Admin API key used by the inference-service gateway, router, and data proxies." + }, + ) + api_url: str | None = field( + default=None, + metadata={ + "help": "External OpenAI-compatible base URL for inference-service external model mode." + }, + ) + provider_api_key: str | None = field( + default=None, + metadata={"help": "API key for the external OpenAI-compatible provider."}, + ) + n_gpus_per_node: int | None = field( + default=None, + metadata={ + "help": "GPUs per physical node for multinode inference-service launch." + }, + ) + def __post_init__(self): """Validate scheduling_spec length.""" if len(self.scheduling_spec) not in (1, 2): @@ -2088,6 +2223,35 @@ def __post_init__(self): f"scheduling_spec must contain 1 or 2 SchedulingSpec, " f"got {len(self.scheduling_spec)}" ) + if self._version not in ("v1", "v2"): + raise ValueError( + f"_version must be either 'v1' or 'v2', got '{self._version}'" + ) + if self.n_gpus_per_node is not None and self.n_gpus_per_node < 1: + raise ValueError( + f"n_gpus_per_node must be >= 1, got {self.n_gpus_per_node}" + ) + if self.session_timeout_seconds <= 0: + raise ValueError( + "session_timeout_seconds must be positive, " + f"got {self.session_timeout_seconds}" + ) + if self.stale_session_cleanup_interval_seconds <= 0: + raise ValueError( + "stale_session_cleanup_interval_seconds must be positive, " + f"got {self.stale_session_cleanup_interval_seconds}" + ) + if not self.admin_api_key or not self.admin_api_key.strip(): + raise ValueError("admin_api_key must not be empty or whitespace-only") + if ( + self._version == "v2" + and self.openai is not None + and self.openai.admin_api_key != "areal-admin-key" + ): + logger.warning( + "rollout.openai.admin_api_key is ignored by rollout controller v2; " + "use rollout.admin_api_key instead." + ) @dataclass diff --git a/areal/experimental/agent_service/README.md b/areal/experimental/agent_service/README.md index f3dc0f839f..30f76bb8ec 100644 --- a/areal/experimental/agent_service/README.md +++ b/areal/experimental/agent_service/README.md @@ -159,9 +159,8 @@ areal/experimental/agent_service/ ├── protocol.py # Gateway protocol frame types ├── types.py # AgentRequest, AgentResponse, EventEmitter, AgentRunnable ├── controller/ -│ ├── __init__.py # AgentServiceController, AgentServiceControllerConfig -│ ├── config.py # AgentServiceControllerConfig dataclass -│ └── controller.py # AgentServiceController orchestrator +│ ├── __init__.py # AgentController export +│ └── controller.py # AgentController orchestrator ├── guard/ │ ├── __init__.py # Module docstring │ ├── __main__.py # python -m areal.experimental.agent_service.guard diff --git a/areal/experimental/agent_service/__init__.py b/areal/experimental/agent_service/__init__.py index 3858d5c133..2c964e550f 100644 --- a/areal/experimental/agent_service/__init__.py +++ b/areal/experimental/agent_service/__init__.py @@ -8,7 +8,7 @@ Submodules ---------- -- ``controller`` — :class:`AgentServiceController` orchestrator +- ``controller`` — :class:`AgentController` orchestrator - ``gateway`` — public HTTP/WebSocket entry point - ``router`` — session-affine routing - ``data_proxy`` — stateful session proxy diff --git a/areal/experimental/agent_service/controller/__init__.py b/areal/experimental/agent_service/controller/__init__.py index 3150205885..a677fc8b08 100644 --- a/areal/experimental/agent_service/controller/__init__.py +++ b/areal/experimental/agent_service/controller/__init__.py @@ -2,10 +2,11 @@ """Agent Service Controller — orchestrator for agent micro-services.""" -from .config import AgentServiceControllerConfig -from .controller import AgentServiceController +from areal.api.cli_args import AgentConfig + +from .controller import AgentController __all__ = [ - "AgentServiceController", - "AgentServiceControllerConfig", + "AgentController", + "AgentConfig", ] diff --git a/areal/experimental/agent_service/controller/config.py b/areal/experimental/agent_service/controller/config.py deleted file mode 100644 index c316d58227..0000000000 --- a/areal/experimental/agent_service/controller/config.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -"""Configuration for the AgentServiceController.""" - -from __future__ import annotations - -from dataclasses import dataclass, field - -from ..auth import DEFAULT_ADMIN_API_KEY - - -@dataclass -class AgentServiceControllerConfig: - """Unified configuration for AgentServiceController. - - Consolidates settings for the guard, router, gateway, worker, and - data proxy micro-services launched by the controller. - """ - - # -- Agent class ------------------------------------------------------- - agent_cls_path: str = "" - """Fully-qualified import path for the ``AgentRunnable`` implementation - (e.g. ``examples.agent_service.agent.Tau2Agent``).""" - - # -- Authentication ---------------------------------------------------- - admin_api_key: str = DEFAULT_ADMIN_API_KEY - """Shared admin API key for inter-service Bearer auth.""" - - # -- Scaling ----------------------------------------------------------- - num_pairs: int = 1 - """Number of Worker+DataProxy pairs to launch on initialize.""" - - # -- Timeouts ---------------------------------------------------------- - setup_timeout: float = 120.0 - """Timeout (seconds) waiting for each service to become healthy.""" - - health_poll_interval: float = 5.0 - """Seconds between health polls for crash detection (0 = disabled).""" - - drain_timeout: float = 30.0 - """Seconds to wait for active sessions to drain before force-killing a pair.""" - - # -- Log level --------------------------------------------------------- - log_level: str = "info" - """Log level for spawned micro-services.""" - - # -- Environment ------------------------------------------------------- - env: dict[str, str] = field(default_factory=dict) - """Extra environment variables to pass to all forked child processes.""" - - def __post_init__(self) -> None: - if not self.agent_cls_path: - raise ValueError("agent_cls_path must be a non-empty import path") - if self.num_pairs < 0: - raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}") - if self.setup_timeout <= 0: - raise ValueError( - f"setup_timeout must be positive, got {self.setup_timeout}" - ) - if self.drain_timeout < 0: - raise ValueError( - f"drain_timeout must be non-negative, got {self.drain_timeout}" - ) diff --git a/areal/experimental/agent_service/controller/controller.py b/areal/experimental/agent_service/controller/controller.py index 21b12851bb..e7465e52be 100644 --- a/areal/experimental/agent_service/controller/controller.py +++ b/areal/experimental/agent_service/controller/controller.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -"""AgentServiceController — orchestrates agent service micro-services via Guards. +"""AgentController — orchestrates agent service micro-services via Guards. Mirrors the architecture of -:class:`~areal.experimental.inference_service.controller.controller.GatewayInferenceController`: +:class:`~areal.experimental.inference_service.controller.controller.RolloutControllerV2`: Guard workers are created via the Scheduler, then the controller forks Router, Worker+DataProxy pairs, and Gateway onto them via HTTP API. @@ -12,7 +12,7 @@ from areal.infra.scheduler.local import LocalScheduler scheduler = LocalScheduler(...) - controller = AgentServiceController(config, scheduler) + controller = AgentController(config, scheduler) controller.initialize() # ... run traffic ... controller.scale_up(2) # add 2 Worker+DataProxy pairs @@ -32,16 +32,14 @@ import requests -from areal.experimental.agent_service.controller.config import ( - AgentServiceControllerConfig, -) +from areal.api.cli_args import AgentConfig from areal.utils import logging from areal.utils.network import format_hostport if TYPE_CHECKING: from areal.api.scheduler_api import Scheduler, Worker -logger = logging.getLogger("AgentServiceController") +logger = logging.getLogger("AgentController") _GUARD_ROLE = "agent-guard" _UNREGISTER_RETRIES = 3 @@ -60,7 +58,7 @@ class _WorkerPair: worker_addr: str -class AgentServiceController: +class AgentController: """Orchestrator for the Agent Service micro-service stack. Parameters @@ -73,7 +71,7 @@ class AgentServiceController: def __init__( self, - config: AgentServiceControllerConfig, + config: AgentConfig, scheduler: Scheduler, ) -> None: self.config = config diff --git a/areal/experimental/inference_service/controller/__init__.py b/areal/experimental/inference_service/controller/__init__.py index e122baa6db..9ac57d0d68 100644 --- a/areal/experimental/inference_service/controller/__init__.py +++ b/areal/experimental/inference_service/controller/__init__.py @@ -1,10 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) +from areal.api.cli_args import InferenceEngineConfig from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) -__all__ = ["GatewayControllerConfig", "GatewayInferenceController"] +__all__ = ["InferenceEngineConfig", "RolloutControllerV2"] diff --git a/areal/experimental/inference_service/controller/config.py b/areal/experimental/inference_service/controller/config.py deleted file mode 100644 index 4d45b1391b..0000000000 --- a/areal/experimental/inference_service/controller/config.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -"""Configuration for the GatewayInferenceController.""" - -from __future__ import annotations - -from dataclasses import dataclass, field - - -@dataclass -class GatewayControllerConfig: - """Unified configuration for GatewayInferenceController. - - Consolidates settings for the gateway, router, data proxy services, - and the WorkflowExecutor / staleness management. - """ - - # -- Model / tokenizer ------------------------------------------------- - tokenizer_path: str = "" - model_path: str = "" - model: str = "default" - - # -- Routing ----------------------------------------------------------- - routing_strategy: str = "round_robin" - poll_interval: float = 5.0 # router health-poll interval (seconds) - - # -- HTTP timeouts ----------------------------------------------------- - request_timeout: float = 120.0 # per-request timeout (seconds) - setup_timeout: float = 300.0 # timeout waiting for services to start - set_reward_finish_timeout: float = 0.0 - - # -- Log level for gateway micro-services ------------------------------ - log_level: str = "info" - - # -- WorkflowExecutor / staleness -------------------------------------- - consumer_batch_size: int = 16 - max_concurrent_rollouts: int | None = None - max_head_offpolicyness: int = 0 - queue_size: int | None = None - enable_rollout_tracing: bool = False - - # -- Trajectory dump --------------------------------------------------- - fileroot: str | None = None - experiment_name: str | None = None - trial_name: str | None = None - check_trajectory_format: bool = False - dump_to_file: bool = False - - # -- Scheduler / allocation (passed through from trainer) -------------- - backend: str = "sglang:d1" - scheduling_spec: tuple = field(default_factory=tuple) - pause_grace_period: float = 0.5 - n_gpus_per_node: int | None = None # GPUs per physical node; None = single-node - - # -- Admin / workflow -------------------------------------------------- - admin_api_key: str | None = None - turn_discount: float = 1.0 - export_style: str = "individual" - tool_call_parser: str = "qwen" - reasoning_parser: str = "qwen3" - engine_max_tokens: int | None = None - chat_template_type: str = "hf" - - # -- External model API ------------------------------------------------ - api_url: str | None = None - provider_api_key: str | None = None diff --git a/areal/experimental/inference_service/controller/controller.py b/areal/experimental/inference_service/controller/controller.py index 7cc546a554..2381cf1d69 100644 --- a/areal/experimental/inference_service/controller/controller.py +++ b/areal/experimental/inference_service/controller/controller.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -"""GatewayInferenceController — parallel implementation to RolloutController. +"""RolloutControllerV2 — parallel implementation to RolloutController. Routes inference and pause/continue traffic through the gateway HTTP stack (Gateway → Router → Data Proxy → inference backend). @@ -11,6 +11,7 @@ from __future__ import annotations import asyncio +import copy import os import sys import threading @@ -28,14 +29,12 @@ if TYPE_CHECKING: from areal.api.scheduler_api import Scheduler, Worker +from areal.api.cli_args import InferenceEngineConfig from areal.api.io_struct import LocalInfServerInfo -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) from areal.utils import logging from areal.utils.network import format_hostport -logger = logging.getLogger("GatewayInferenceController") +logger = logging.getLogger("RolloutControllerV2") _MAX_COMPLETED_ONLINE_RESULTS = 1024 @@ -48,7 +47,7 @@ class _OnlineWaiter: class _DummyDataLoader: """Minimal dataloader that yields a single batch of empty dicts. - Used by :meth:`GatewayInferenceController.prepare_batch` when + Used by :meth:`RolloutControllerV2.prepare_batch` when ``dataloader`` is ``None`` (online-agent mode). """ @@ -59,7 +58,7 @@ def __iter__(self): yield [{} for _ in range(self.batch_size)] -class GatewayInferenceController: +class RolloutControllerV2: """Inference controller that routes everything through the gateway HTTP stack. This is a **parallel** implementation to ``RolloutController`` (NOT a @@ -79,15 +78,15 @@ class GatewayInferenceController: def __init__( self, - config: GatewayControllerConfig, + config: InferenceEngineConfig, scheduler: Scheduler, ) -> None: - if config.admin_api_key is None: + if config.admin_api_key is None or not config.admin_api_key.strip(): raise ValueError( - "GatewayControllerConfig.admin_api_key must be set (not None)" + "InferenceEngineConfig.admin_api_key must be set (not None or empty)" ) if not config.model: - raise ValueError("GatewayControllerConfig.model must not be empty") + raise ValueError("InferenceEngineConfig.model must not be empty") self.config = config self.scheduler = scheduler @@ -199,7 +198,6 @@ def initialize( self._register_data_proxies_in_router() # Create WorkflowExecutor directly (no intermediate engine) - from areal.api.cli_args import InferenceEngineConfig from areal.infra.remote_inf_engine import RemoteInfEngine from areal.infra.workflow_executor import WorkflowExecutor @@ -222,7 +220,7 @@ def initialize( max_staleness=self.config.max_head_offpolicyness, ) - logger.info("GatewayInferenceController initialized (role=%s)", role) + logger.info("RolloutControllerV2 initialized (role=%s)", role) if self.config.model: self.register_model( @@ -237,6 +235,18 @@ def initialize( self.config.model, ) + def offload(self) -> None: + """Offload hook placeholder for trainer compatibility.""" + logger.warning( + "RolloutControllerV2.offload is not implemented and will be skipped" + ) + + def onload(self, tags: list[str] | None = None) -> None: + """Onload hook placeholder for trainer compatibility.""" + logger.warning( + "RolloutControllerV2.onload is not implemented and will be skipped" + ) + async def _async_initialize( self, server_args: dict[str, Any] | None, @@ -266,6 +276,7 @@ async def _async_initialize( cfg = self.config admin_api_key = self.config.admin_api_key + openai_cfg = self._openai_config if self.external_mode: dp_size = 1 @@ -350,10 +361,9 @@ async def _async_initialize( if inf_backend == "sglang": from areal.api.cli_args import SGLangConfig - sglang_config = SGLangConfig( - model_path=cfg.model_path or cfg.tokenizer_path, - ) + sglang_config = SGLangConfig() if server_args: + sglang_config = copy.deepcopy(sglang_config) for k, v in server_args.items(): if hasattr(sglang_config, k): setattr(sglang_config, k, v) @@ -386,17 +396,19 @@ def _build_launch_cmd( elif inf_backend == "vllm": from areal.api.cli_args import vLLMConfig - vllm_config = vLLMConfig(model=cfg.model_path or cfg.tokenizer_path) - for k, v in (server_args or {}).items(): - if hasattr(vllm_config, k): - setattr(vllm_config, k, v) - else: - logger.warning( - "vLLMConfig has no attribute %r, ignoring " - "server_args entry (value=%r)", - k, - v, - ) + vllm_config = vLLMConfig() + if server_args: + vllm_config = copy.deepcopy(vllm_config) + for k, v in server_args.items(): + if hasattr(vllm_config, k): + setattr(vllm_config, k, v) + else: + logger.warning( + "vLLMConfig has no attribute %r, ignoring " + "server_args entry (value=%r)", + k, + v, + ) def _build_launch_cmd( host: str | None, @@ -580,19 +592,25 @@ def _build_launch_cmd( str(cfg.request_timeout), "--set-reward-finish-timeout", str(cfg.set_reward_finish_timeout), + "--session-timeout-seconds", + str(cfg.session_timeout_seconds), + "--stale-session-cleanup-interval-seconds", + str(cfg.stale_session_cleanup_interval_seconds), + "--stale-session-dump-path", + cfg.stale_session_dump_path, "--callback-server-addr", f"http://{self.callback_addr}", "--tool-call-parser", - cfg.tool_call_parser, + openai_cfg.tool_call_parser, "--reasoning-parser", - cfg.reasoning_parser, + openai_cfg.reasoning_parser, "--chat-template-type", - cfg.chat_template_type, + openai_cfg.chat_template_type, ] - if cfg.engine_max_tokens is not None: + if openai_cfg.engine_max_tokens is not None: data_proxy_base_cmd += [ "--engine-max-tokens", - str(cfg.engine_max_tokens), + str(openai_cfg.engine_max_tokens), ] for group_idx in range(dp_size): @@ -951,9 +969,7 @@ def get_version(self) -> int: def get_capacity(self) -> int: if self.staleness_manager is None: - raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") return self.staleness_manager.get_capacity() # -- Submit / Wait / Batch --------------------------------------------- @@ -1044,9 +1060,7 @@ def rollout_batch( A list of trajectory dicts (one per completed rollout). """ if not self._gateway_addr: - raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") if data is None: if batch_size is None: raise ValueError( @@ -1115,9 +1129,7 @@ def prepare_batch( A list of trajectory dicts (matching ``RolloutController`` API). """ if not self._gateway_addr: - raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") if dataloader is None: if batch_size is None: raise ValueError( @@ -1321,9 +1333,7 @@ def staleness_manager(self): @property def workflow_executor(self): if self._workflow_executor is None: - raise RuntimeError( - "GatewayInferenceController.initialize() must be called first" - ) + raise RuntimeError("RolloutControllerV2.initialize() must be called first") return self._workflow_executor @property @@ -1353,8 +1363,8 @@ def _wrap_agent(self, agent: Any): "Gateway address is unavailable; initialize the controller first" ) - openai_cfg = self.config - admin_api_key = openai_cfg.admin_api_key + openai_cfg = self._openai_config + admin_api_key = self.config.admin_api_key turn_discount = openai_cfg.turn_discount export_style = openai_cfg.export_style @@ -1440,13 +1450,13 @@ def _resolve_workflow( # (c) Reject RolloutWorkflow classes and instances if isinstance(agent, type) and issubclass(agent, RolloutWorkflow): raise TypeError( - "GatewayInferenceController only accepts agent classes or instances with a " + "RolloutControllerV2 only accepts agent classes or instances with a " "run() method or None for online mode; direct RolloutWorkflow " "classes are not supported" ) if isinstance(agent, RolloutWorkflow): raise TypeError( - "GatewayInferenceController only accepts agent classes or instances with a " + "RolloutControllerV2 only accepts agent classes or instances with a " "run() method or None for online mode; direct RolloutWorkflow " "instances are not supported" ) @@ -1490,6 +1500,14 @@ def _resolve_should_accept_fn( return cast(Callable[[dict[str, Any]], bool], func) raise TypeError(f"Invalid should_accept_fn type: {type(should_accept_fn)}") + @property + def _openai_config(self): + from areal.api.cli_args import OpenAIProxyConfig + + return self.config.openai or OpenAIProxyConfig( + admin_api_key=self.config.admin_api_key + ) + # -- Internal HTTP helpers --------------------------------------------- def _fork_on_guard( diff --git a/areal/experimental/inference_service/controller/workflow.py b/areal/experimental/inference_service/controller/workflow.py index 95f0571770..3bbfeb8b97 100644 --- a/areal/experimental/inference_service/controller/workflow.py +++ b/areal/experimental/inference_service/controller/workflow.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from areal.api.engine_api import InferenceEngine from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.openai.types import InteractionWithTokenLogpReward @@ -29,7 +29,7 @@ class InferenceServiceWorkflow(RolloutWorkflow): def __init__( self, - controller: GatewayInferenceController, + controller: RolloutControllerV2, agent: Any | None = None, gateway_addr: str = "", admin_api_key: str = "areal-admin-key", diff --git a/areal/experimental/inference_service/data_proxy/__main__.py b/areal/experimental/inference_service/data_proxy/__main__.py index ef2ea87894..dfd50f4b80 100644 --- a/areal/experimental/inference_service/data_proxy/__main__.py +++ b/areal/experimental/inference_service/data_proxy/__main__.py @@ -44,6 +44,20 @@ def main(): type=float, default=0.0, ) + parser.add_argument( + "--session-timeout-seconds", + type=float, + default=3600.0, + ) + parser.add_argument( + "--stale-session-cleanup-interval-seconds", + type=float, + default=60.0, + ) + parser.add_argument( + "--stale-session-dump-path", + default="", + ) parser.add_argument( "--admin-api-key", default="areal-admin-key", @@ -88,6 +102,9 @@ def main(): log_level=args.log_level, request_timeout=args.request_timeout, set_reward_finish_timeout=args.set_reward_finish_timeout, + session_timeout_seconds=args.session_timeout_seconds, + stale_session_cleanup_interval_seconds=args.stale_session_cleanup_interval_seconds, + stale_session_dump_path=args.stale_session_dump_path, admin_api_key=args.admin_api_key, callback_server_addr=args.callback_server_addr, serving_addr=format_hostport(serving_host, args.port), diff --git a/areal/experimental/inference_service/data_proxy/app.py b/areal/experimental/inference_service/data_proxy/app.py index 6ab04c8d32..18c675bb87 100644 --- a/areal/experimental/inference_service/data_proxy/app.py +++ b/areal/experimental/inference_service/data_proxy/app.py @@ -224,6 +224,22 @@ async def _ready_trajectory_loop(app: FastAPI) -> None: await asyncio.sleep(0.1) +async def _cleanup_stale_sessions(app: FastAPI) -> None: + store: SessionStore = app.state.session_store + config: DataProxyConfig = app.state.config + await store.cleanup_stale( + timeout_seconds=config.session_timeout_seconds, + stale_session_dump_path=config.stale_session_dump_path, + serving_addr=config.serving_addr, + ) + + +async def _stale_session_cleanup_loop(app: FastAPI) -> None: + while True: + await _cleanup_stale_sessions(app) + await asyncio.sleep(app.state.config.stale_session_cleanup_interval_seconds) + + def create_app(config: DataProxyConfig) -> FastAPI: """Factory that creates the FastAPI app with lifespan-managed resources.""" @@ -257,14 +273,20 @@ async def lifespan(app: FastAPI): app.state.areal_client = areal_client ready_task = asyncio.create_task(_ready_trajectory_loop(app)) + cleanup_task = asyncio.create_task(_stale_session_cleanup_loop(app)) try: yield finally: ready_task.cancel() + cleanup_task.cancel() try: await ready_task except asyncio.CancelledError: pass + try: + await cleanup_task + except asyncio.CancelledError: + pass logger.info("Data proxy shutting down") app = FastAPI(title="AReaL Data Proxy", lifespan=lifespan) diff --git a/areal/experimental/inference_service/data_proxy/config.py b/areal/experimental/inference_service/data_proxy/config.py index 44b2f6fe91..5aed9c58e8 100644 --- a/areal/experimental/inference_service/data_proxy/config.py +++ b/areal/experimental/inference_service/data_proxy/config.py @@ -2,6 +2,10 @@ from dataclasses import dataclass +from areal.experimental.inference_service.data_proxy.session import ( + SESSION_TIMEOUT_SECONDS, +) + @dataclass class DataProxyConfig: @@ -13,6 +17,9 @@ class DataProxyConfig: log_level: str = "info" request_timeout: float = 120.0 # seconds per SGLang call set_reward_finish_timeout: float = 0.0 + session_timeout_seconds: float = SESSION_TIMEOUT_SECONDS + stale_session_cleanup_interval_seconds: float = 60.0 + stale_session_dump_path: str = "" max_resubmit_retries: int = 20 # max abort/resubmit cycles before giving up resubmit_wait: float = 0.5 # seconds between is_paused polls admin_api_key: str = "areal-admin-key" # admin key for authentication diff --git a/areal/experimental/inference_service/data_proxy/session.py b/areal/experimental/inference_service/data_proxy/session.py index 5646104fb8..6ed3dc39ed 100644 --- a/areal/experimental/inference_service/data_proxy/session.py +++ b/areal/experimental/inference_service/data_proxy/session.py @@ -4,18 +4,26 @@ from __future__ import annotations +import asyncio +import json import secrets import threading import time import uuid from collections import OrderedDict from dataclasses import dataclass +from pathlib import Path from typing import Any +import aiofiles +import aiofiles.os from pydantic import BaseModel from areal.experimental.openai.cache import InteractionCache +from areal.experimental.openai.proxy.server import serialize_interactions from areal.experimental.openai.types import InteractionWithTokenLogpReward +from areal.infra.rpc import rtensor as rtensor_storage +from areal.infra.rpc.rtensor import RTensor # Session timeout for cleanup (1 hour) SESSION_TIMEOUT_SECONDS = 3600 @@ -338,6 +346,127 @@ def export_trajectory( ) return ready.trajectory_id, interactions + def snapshot_cleanup_exports( + self, + discount: float, + style: str, + ) -> list[tuple[int | None, dict[str, InteractionWithTokenLogpReward]]]: + """Snapshot exportable trajectory data for stale-session cleanup. + + This does not mutate session membership in ``SessionStore``. It may apply + reward discounting in-place on cached interactions the first time it runs, + matching the existing export behavior. + """ + + exports: list[tuple[int | None, dict[str, InteractionWithTokenLogpReward]]] = [] + + def _export_cache( + cache: InteractionCache, + ) -> dict[str, InteractionWithTokenLogpReward]: + reward_discount = None if cache._apply_reward_discount_called else discount + return cache.export_interactions( + style=style, + reward_discount=reward_discount, + ) + + with self._lock: + for ready in self._ready_trajectories.values(): + interactions = _export_cache(ready.completions) + if interactions: + exports.append((ready.trajectory_id, interactions)) + + if len(self._active_completions) != 0: + interactions = _export_cache(self._active_completions) + if interactions: + exports.append((None, interactions)) + + return exports + + +def _prepare_interactions_for_stale_dump( + interactions: dict[str, Any], + serving_addr: str, +) -> dict[str, list[Any]]: + def _localize_for_stale_dump(obj: Any) -> Any: + if isinstance(obj, RTensor): + if obj.data.is_meta and ( + not obj.shard.node_addr or obj.shard.node_addr == serving_addr + ): + return rtensor_storage.fetch(obj.shard.shard_id) + return obj + + if isinstance(obj, dict): + return {k: _localize_for_stale_dump(v) for k, v in obj.items()} + + if isinstance(obj, list): + return [_localize_for_stale_dump(item) for item in obj] + + if isinstance(obj, tuple): + return tuple(_localize_for_stale_dump(item) for item in obj) + + return obj + + shards_by_node: dict[str, list[Any]] = {} + + for item in interactions.values(): + if item.has_tensor_data: + tensor_dict = item.to_tensor_dict() + shard_map = RTensor.collect_shards(tensor_dict) + if shard_map: + # Stale-session dumps must be self-contained. If an interaction cache + # already contains RTensors (for example because /export_trajectories + # remotized it earlier), serialize_interactions() would otherwise dump + # RTensor metadata that depends on external shard storage. + # Localize everything back to plain torch.Tensor values before writing + # the dump file so the dumped JSON remains reloadable on its own. + localized = RTensor.localize(_localize_for_stale_dump(tensor_dict)) + item._cache = localized + else: + item._cache = tensor_dict + + for node_addr, shard_ids in shard_map.items(): + shards_by_node.setdefault(node_addr, []).extend(shard_ids) + + return shards_by_node + + +async def _dump_stale_session_exports( + session_id: str, + exports: list[tuple[int | None, dict[str, Any]]], + stale_session_dump_path: str | None, + serving_addr: str, +) -> None: + if not stale_session_dump_path: + return + + dump_dir = Path(stale_session_dump_path) + await aiofiles.os.makedirs(dump_dir, exist_ok=True) + + for trajectory_id, interactions in exports: + # Prepare a self-contained dump payload and remember any RTensor shards + # that backed the interactions before localization. + shards_by_node = _prepare_interactions_for_stale_dump( + interactions, serving_addr + ) + serialized = serialize_interactions(interactions) + trajectory_suffix = ( + f"trajectory-{trajectory_id}" if trajectory_id is not None else "active" + ) + dump_path = dump_dir / f"{session_id}-{trajectory_suffix}.json" + serialized_json = await asyncio.to_thread(json.dumps, serialized) + async with aiofiles.open(dump_path, "w") as dump_file: + await dump_file.write(serialized_json) + + # Once the dump has been written with inline tensors, the original RTensor + # shards are no longer needed for stale-session recovery. Remove them to + # avoid leaking local/remote shard storage. + for node_addr, shard_ids in shards_by_node.items(): + if not node_addr or node_addr == serving_addr: + for shard_id in shard_ids: + rtensor_storage.remove(shard_id) + continue + await RTensor.clear_node(node_addr, shard_ids) + # ============================================================================= # Session Store @@ -445,19 +574,43 @@ def _remove_api_keys_for_session(self, session_id: str) -> None: if api_key: self._api_key_to_session.pop(api_key, None) - def cleanup_stale(self, timeout_seconds: float = SESSION_TIMEOUT_SECONDS) -> None: + async def cleanup_stale( + self, + timeout_seconds: float = SESSION_TIMEOUT_SECONDS, + stale_session_dump_path: str | None = None, + serving_addr: str = "", + ) -> None: with self._lock: - stale_sessions: list[str] = [] - for sid, session in self._sessions.items(): + stale_sessions: list[tuple[str, SessionData]] = [] + for sid, session in list(self._sessions.items()): if not session.is_stale(timeout_seconds): continue - if session.has_ready_trajectories: - continue - stale_sessions.append(sid) - - for sid in stale_sessions: self._sessions.pop(sid, None) self._remove_api_keys_for_session(sid) + stale_sessions.append((sid, session)) + + for sid, session in stale_sessions: + exports = session.snapshot_cleanup_exports( + discount=1.0, + style="individual", + ) + await _dump_stale_session_exports( + sid, + exports, + stale_session_dump_path=stale_session_dump_path, + serving_addr=serving_addr, + ) + + def stale_session_ids( + self, + timeout_seconds: float = SESSION_TIMEOUT_SECONDS, + ) -> list[str]: + with self._lock: + return [ + sid + for sid, session in self._sessions.items() + if session.is_stale(timeout_seconds) + ] def finalize_rewarded_trajectories( self, diff --git a/areal/infra/data_service/controller/controller.py b/areal/infra/data_service/controller/controller.py index 5bfd2ae7ab..67188b8150 100644 --- a/areal/infra/data_service/controller/controller.py +++ b/areal/infra/data_service/controller/controller.py @@ -5,7 +5,7 @@ Manages the full lifecycle: create RPCGuard workers → fork DataWorkers, Router, Gateway → register datasets → serve batches → shutdown. -Follows the same patterns as ``GatewayInferenceController``. +Follows the same patterns as ``RolloutControllerV2``. """ from __future__ import annotations @@ -33,7 +33,7 @@ class DataController: """Controller for the distributed data loading service. - API follows ``TrainController`` / ``GatewayInferenceController`` patterns: + API follows ``TrainController`` / ``RolloutControllerV2`` patterns: ``__init__(config, scheduler)`` then ``initialize(role, ...)``. """ diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 978a6c18de..61fd4f994a 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -6,7 +6,7 @@ import os from collections.abc import Callable from copy import deepcopy -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch.distributed as dist from torchdata.stateful_dataloader import StatefulDataLoader @@ -35,6 +35,9 @@ vLLMConfig, ) from areal.engine import RemoteSGLangEngine, RemotevLLMEngine +from areal.experimental.inference_service.controller.controller import ( + RolloutControllerV2, +) from areal.infra import ( LocalScheduler, RayScheduler, @@ -1008,7 +1011,12 @@ def _init_rollout( return engine # Single-controller mode - no engine instantiation needed - controller = engine_cls.as_controller(config, self.scheduler) + if config._version == "v2": + controller = RolloutControllerV2( + config=config, scheduler=cast(Scheduler, self.scheduler) + ) + else: + controller = engine_cls.as_controller(config, self.scheduler) init_kwargs = dict( role="rollout", server_args=server_args, diff --git a/areal/utils/logging.py b/areal/utils/logging.py index 3d4f83bcd3..8b1130af8c 100644 --- a/areal/utils/logging.py +++ b/areal/utils/logging.py @@ -110,9 +110,9 @@ "AgentRouter": "light_purple", "AgentWorker": "light_purple", "AgentDataProxy": "light_purple", - "AgentServiceController": "light_purple", + "AgentController": "light_purple", # Inference service - white (orchestration) - "GatewayInferenceController": "white", + "RolloutControllerV2": "white", "InferenceDataProxy": "white", "InferenceInfBridge": "white", "InferenceRouter": "white", diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 8ea5263c78..eb1ba30fa0 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -71,6 +71,7 @@ For detailed examples, see the experiment configurations in the `examples/` dire ### Others +- [Agent Configuration](section-agent) - [ArchonEngine Configuration](section-archon-engine) - [ArchonFP8 Configuration](section-archon-fp8) - [DPO Configuration](section-dpo) @@ -527,30 +528,43 @@ Controls text generation behavior for rollout. Configuration for inference servers, including offpolicyness control. -| Parameter | Type | Default | Description | -| ------------------------- | ---------------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `experiment_name` | string \| None | `None` | - | -| `trial_name` | string \| None | `None` | - | -| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | -| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | -| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | -| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | -| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | -| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | -| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | -| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | -| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | -| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | -| `request_timeout` | float | `3600` | Timeout for HTTP requests. | -| `request_retries` | integer | `3` | Number of retries for failed requests. | -| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | -| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | -| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | -| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| Parameter | Type | Default | Description | +| ---------------------------------------- | ---------------------------------------------------- | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `experiment_name` | string \| None | `None` | - | +| `trial_name` | string \| None | `None` | - | +| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | +| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | +| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | +| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | +| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | +| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | +| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | +| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | +| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | +| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | +| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | +| `request_timeout` | float | `3600` | Timeout for HTTP requests. | +| `request_retries` | integer | `3` | Number of retries for failed requests. | +| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | +| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | +| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| `_version` | string | `"v1"` | Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2. **Choices:** `v1`, `v2` | +| `model` | string | `"default"` | Model name exposed through the inference-service gateway. | +| `routing_strategy` | string | `"round_robin"` | Routing strategy for the inference-service router. | +| `poll_interval` | float | `5.0` | Health-poll interval in seconds for the inference-service router. | +| `set_reward_finish_timeout` | float | `0.0` | Timeout in seconds to wait for additional reward updates before finalizing a session. | +| `session_timeout_seconds` | float | `3600.0` | Timeout in seconds before an inactive inference-service session is considered stale and cleaned up. | +| `stale_session_cleanup_interval_seconds` | float | `60.0` | Polling interval in seconds for stale-session cleanup in inference-service data proxies. | +| `stale_session_dump_path` | string | `""` | Optional directory path where stale-session trajectory dumps are written before cleanup. | +| `log_level` | string | `"info"` | Log level for inference-service micro-services. | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by the inference-service gateway, router, and data proxies. | +| `api_url` | string \| None | `None` | External OpenAI-compatible base URL for inference-service external model mode. | +| `provider_api_key` | string \| None | `None` | API key for the external OpenAI-compatible provider. | +| `n_gpus_per_node` | integer \| None | `None` | GPUs per physical node for multinode inference-service launch. | (section-sg-lang)= @@ -856,6 +870,23 @@ Configuration for Weights & Biases experiment tracking. | `config` | `dict` \| None | `None` | - | | `id_suffix` | string \| None | `"train"` | - | +(section-agent)= + +## Agent Configuration + +Configuration for the experimental agent service controller. + +| Parameter | Type | Default | Description | +| ---------------------- | ------- | --------------------- | ------------------------------------------------------------------------- | +| `agent_cls_path` | string | `""` | Fully-qualified import path for the AgentRunnable implementation. | +| `admin_api_key` | string | `"areal-agent-admin"` | Shared admin API key for agent-service inter-service auth. | +| `num_pairs` | integer | `1` | Number of Worker+DataProxy pairs to launch on initialize. | +| `setup_timeout` | float | `120.0` | Timeout in seconds waiting for each service to become healthy. | +| `health_poll_interval` | float | `5.0` | Seconds between pair health polls; 0 disables health monitoring. | +| `drain_timeout` | float | `30.0` | Seconds to wait for active sessions to drain before force-killing a pair. | +| `log_level` | string | `"info"` | Log level for spawned agent-service micro-services. | +| `env` | `dict` | **Required** | Extra environment variables passed to all forked child processes. | + (section-archon-engine)= ## ArchonEngine Configuration diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index b1ca97aa3e..24fc1f3064 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -69,6 +69,7 @@ python3 train.py --config path/to/config.yaml actor.lr=1e-4 seed=42 ### Others +- [Agent Configuration](section-agent) - [ArchonEngine Configuration](section-archon-engine) - [ArchonFP8 Configuration](section-archon-fp8) - [DPO Configuration](section-dpo) @@ -525,30 +526,43 @@ Controls text generation behavior for rollout. Configuration for inference servers, including offpolicyness control. -| Parameter | Type | Default | Description | -| ------------------------- | ---------------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `experiment_name` | string \| None | `None` | - | -| `trial_name` | string \| None | `None` | - | -| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | -| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | -| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | -| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | -| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | -| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | -| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | -| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | -| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | -| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | -| `request_timeout` | float | `3600` | Timeout for HTTP requests. | -| `request_retries` | integer | `3` | Number of retries for failed requests. | -| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | -| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | -| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | -| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| Parameter | Type | Default | Description | +| ---------------------------------------- | ---------------------------------------------------- | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `experiment_name` | string \| None | `None` | - | +| `trial_name` | string \| None | `None` | - | +| `fileroot` | string \| None | `None` | Root directory for logs and trajectory dumps. | +| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | +| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | +| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | +| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | +| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | +| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | +| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | +| `tokenizer_path` | string | `""` | Path to tokenizer for trajectory text decoding. | +| `dump_to_file` | boolean | `False` | Whether to dump the trajectories to files under fileroot. | +| `setup_timeout` | float | `300.0` | Timeout in seconds of connecting to remote servers or launching local servers. | +| `request_timeout` | float | `3600` | Timeout for HTTP requests. | +| `request_retries` | integer | `3` | Number of retries for failed requests. | +| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| `scheduling_spec` | `tuple` | **Required** | inference engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the RolloutController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'sglang:d4', 'vllm:d2t4'. Required. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | +| `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | +| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| `_version` | string | `"v1"` | Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2. **Choices:** `v1`, `v2` | +| `model` | string | `"default"` | Model name exposed through the inference-service gateway. | +| `routing_strategy` | string | `"round_robin"` | Routing strategy for the inference-service router. | +| `poll_interval` | float | `5.0` | Health-poll interval in seconds for the inference-service router. | +| `set_reward_finish_timeout` | float | `0.0` | Timeout in seconds to wait for additional reward updates before finalizing a session. | +| `session_timeout_seconds` | float | `3600.0` | Timeout in seconds before an inactive inference-service session is considered stale and cleaned up. | +| `stale_session_cleanup_interval_seconds` | float | `60.0` | Polling interval in seconds for stale-session cleanup in inference-service data proxies. | +| `stale_session_dump_path` | string | `""` | Optional directory path where stale-session trajectory dumps are written before cleanup. | +| `log_level` | string | `"info"` | Log level for inference-service micro-services. | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by the inference-service gateway, router, and data proxies. | +| `api_url` | string \| None | `None` | External OpenAI-compatible base URL for inference-service external model mode. | +| `provider_api_key` | string \| None | `None` | API key for the external OpenAI-compatible provider. | +| `n_gpus_per_node` | integer \| None | `None` | GPUs per physical node for multinode inference-service launch. | (section-sg-lang)= @@ -854,6 +868,23 @@ Configuration for Weights & Biases experiment tracking. | `config` | `dict` \| None | `None` | - | | `id_suffix` | string \| None | `"train"` | - | +(section-agent)= + +## Agent Configuration + +Configuration for the experimental agent service controller. + +| Parameter | Type | Default | Description | +| ---------------------- | ------- | --------------------- | ------------------------------------------------------------------------- | +| `agent_cls_path` | string | `""` | Fully-qualified import path for the AgentRunnable implementation. | +| `admin_api_key` | string | `"areal-agent-admin"` | Shared admin API key for agent-service inter-service auth. | +| `num_pairs` | integer | `1` | Number of Worker+DataProxy pairs to launch on initialize. | +| `setup_timeout` | float | `120.0` | Timeout in seconds waiting for each service to become healthy. | +| `health_poll_interval` | float | `5.0` | Seconds between pair health polls; 0 disables health monitoring. | +| `drain_timeout` | float | `30.0` | Seconds to wait for active sessions to drain before force-killing a pair. | +| `log_level` | string | `"info"` | Log level for spawned agent-service micro-services. | +| `env` | `dict` | **Required** | Extra environment variables passed to all forked child processes. | + (section-archon-engine)= ## ArchonEngine Configuration diff --git a/examples/agent_service/README.md b/examples/agent_service/README.md index 563064b4e5..1c7e8bc839 100644 --- a/examples/agent_service/README.md +++ b/examples/agent_service/README.md @@ -86,14 +86,14 @@ Turn 2: Client → Gateway → Router (same DataProxy) → DataProxy → Worker ```python from areal.experimental.agent_service.controller import ( - AgentServiceController, - AgentServiceControllerConfig, + AgentController, ) +from areal.api.cli_args import AgentConfig from areal.infra.scheduler.local import LocalScheduler scheduler = LocalScheduler(experiment_name="demo", trial_name="run0", gpu_devices=[]) -ctrl = AgentServiceController( - config=AgentServiceControllerConfig( +ctrl = AgentController( + config=AgentConfig( agent_cls_path="examples.agent_service.agent.ClaudeAgent", num_pairs=2, ), diff --git a/examples/agent_service/run_agent_service.py b/examples/agent_service/run_agent_service.py index e96f83f501..80850dc4ca 100644 --- a/examples/agent_service/run_agent_service.py +++ b/examples/agent_service/run_agent_service.py @@ -21,10 +21,8 @@ import httpx -from areal.experimental.agent_service.controller import ( - AgentServiceController, - AgentServiceControllerConfig, -) +from areal.api.cli_args import AgentConfig +from areal.experimental.agent_service.controller import AgentController async def _wait_healthy(url: str, timeout: float = 60.0) -> None: @@ -103,12 +101,12 @@ def main() -> None: gpu_devices=[], ) - ctrl_config = AgentServiceControllerConfig( + ctrl_config = AgentConfig( agent_cls_path="examples.agent_service.agent.ClaudeAgent", admin_api_key=args.admin_api_key, num_pairs=args.num_pairs, ) - ctrl = AgentServiceController(config=ctrl_config, scheduler=scheduler) + ctrl = AgentController(config=ctrl_config, scheduler=scheduler) try: print(f"Initializing with {args.num_pairs} pair(s) ...") diff --git a/examples/experimental/inference_service/README.md b/examples/experimental/inference_service/README.md index 1b7c0af21f..a039974b7b 100644 --- a/examples/experimental/inference_service/README.md +++ b/examples/experimental/inference_service/README.md @@ -1,7 +1,7 @@ # AReaL Inference Service Examples This directory contains two examples that use the AReaL Inference Service -(`GatewayInferenceController`) — an experimental rollout backend that exposes an +(`RolloutControllerV2`) — an experimental rollout backend that exposes an OpenAI-compatible proxy gateway so any external agent runtime can submit chat requests and receive RL training data. diff --git a/examples/experimental/inference_service/human_in_the_loop_demo.py b/examples/experimental/inference_service/human_in_the_loop_demo.py index 96f08c00a8..a6e833512d 100644 --- a/examples/experimental/inference_service/human_in_the_loop_demo.py +++ b/examples/experimental/inference_service/human_in_the_loop_demo.py @@ -283,7 +283,8 @@ def cleanup(signum=None, frame=None): str(config_yaml), f"actor.path={args.actor_path}", f"rollout.backend={args.inference_backend}:d1", - f"rollout.openai.admin_api_key={args.admin_key}", + f"rollout.admin_api_key={args.admin_key}", + "rollout._version=v2", f"rollout.request_timeout={args.request_timeout}", ] if args.api_url: diff --git a/examples/experimental/inference_service/online_rollout.py b/examples/experimental/inference_service/online_rollout.py index 106f8e86d0..dd0bfd1bdb 100644 --- a/examples/experimental/inference_service/online_rollout.py +++ b/examples/experimental/inference_service/online_rollout.py @@ -4,6 +4,7 @@ import argparse import sys +from copy import deepcopy from dataclasses import asdict from pathlib import Path @@ -20,11 +21,8 @@ def main(args: list[str]) -> None: ext_args, remaining = parser.parse_known_args(args) from areal.api.cli_args import PPOConfig, load_expr_config - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.utils import logging from areal.utils.environ import is_single_controller @@ -54,26 +52,8 @@ def main(args: list[str]) -> None: is_external = ext_args.api_url is not None - ctrl_config = GatewayControllerConfig( - tokenizer_path=config.tokenizer_path, - model_path=config.actor.path, - consumer_batch_size=config.rollout.consumer_batch_size, - max_concurrent_rollouts=config.rollout.max_concurrent_rollouts, - max_head_offpolicyness=config.rollout.max_head_offpolicyness, - queue_size=config.rollout.queue_size, - enable_rollout_tracing=config.rollout.enable_rollout_tracing, - fileroot=config.rollout.fileroot, - experiment_name=config.rollout.experiment_name, - trial_name=config.rollout.trial_name, - dump_to_file=False, - backend=config.rollout.backend, - scheduling_spec=config.rollout.scheduling_spec, - setup_timeout=config.rollout.setup_timeout, - request_timeout=config.rollout.request_timeout, - admin_api_key=openai_cfg.admin_api_key, - turn_discount=openai_cfg.turn_discount, - export_style=openai_cfg.export_style, - ) + ctrl_config = deepcopy(config.rollout) + ctrl_config.dump_to_file = False if ext_args.model: ctrl_config.model = ext_args.model if is_external: @@ -91,7 +71,7 @@ def main(args: list[str]) -> None: else: raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") - ctrl = GatewayInferenceController(config=ctrl_config, scheduler=scheduler) + ctrl = RolloutControllerV2(config=ctrl_config, scheduler=scheduler) try: ctrl.initialize( role="rollout", diff --git a/examples/experimental/inference_service/online_rollout.yaml b/examples/experimental/inference_service/online_rollout.yaml index 307afdfa32..923a16da9e 100644 --- a/examples/experimental/inference_service/online_rollout.yaml +++ b/examples/experimental/inference_service/online_rollout.yaml @@ -19,6 +19,7 @@ scheduler: type: local rollout: + _version: v2 backend: "sglang:d1" experiment_name: ${experiment_name} trial_name: ${trial_name} @@ -33,6 +34,7 @@ rollout: dump_to_file: true request_timeout: 120.0 setup_timeout: 300.0 + admin_api_key: sk-test123456 openai: mode: online export_style: individual diff --git a/examples/experimental/inference_service/tau2_rollout.py b/examples/experimental/inference_service/tau2_rollout.py index d7c6828216..b78701b023 100644 --- a/examples/experimental/inference_service/tau2_rollout.py +++ b/examples/experimental/inference_service/tau2_rollout.py @@ -1,4 +1,4 @@ -"""Rollout-only script for Tau2 benchmark using GatewayInferenceController. +"""Rollout-only script for Tau2 benchmark using RolloutControllerV2. This example demonstrates how to run rollouts (data generation) without training, using the gateway HTTP stack to route inference requests. @@ -13,6 +13,7 @@ import sys import warnings +from copy import deepcopy from dataclasses import asdict, dataclass, field from typing import Any @@ -27,11 +28,8 @@ TrainDatasetConfig, load_expr_config, ) -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.utils import logging @@ -138,7 +136,7 @@ def get_tau2_dataset( @dataclass class Tau2GatewayRolloutConfig(BaseExperimentConfig): - """Configuration for Tau2 rollout-only with GatewayInferenceController.""" + """Configuration for Tau2 rollout-only with RolloutControllerV2.""" gconfig: GenerationHyperparameters = field( default_factory=GenerationHyperparameters @@ -160,7 +158,8 @@ def main(argv: list[str]) -> None: config, _ = load_expr_config(argv, Tau2GatewayRolloutConfig) econfig = config.econfig - rollout_cfg = config.rollout + rollout_cfg = deepcopy(config.rollout) + rollout_cfg.model = config.model_path # --- Dataset --- train_dataset = get_tau2_dataset( @@ -178,26 +177,6 @@ def main(argv: list[str]) -> None: num_workers=0, # in-process; tau2 dataset is lightweight ) - # --- Build GatewayControllerConfig from YAML rollout section --- - ctrl_config = GatewayControllerConfig( - tokenizer_path=config.tokenizer_path, - model_path=config.model_path, - consumer_batch_size=rollout_cfg.consumer_batch_size, - max_concurrent_rollouts=rollout_cfg.max_concurrent_rollouts, - max_head_offpolicyness=rollout_cfg.max_head_offpolicyness, - queue_size=rollout_cfg.queue_size, - enable_rollout_tracing=rollout_cfg.enable_rollout_tracing, - fileroot=rollout_cfg.fileroot, - experiment_name=rollout_cfg.experiment_name, - trial_name=rollout_cfg.trial_name, - dump_to_file=rollout_cfg.dump_to_file, - backend=rollout_cfg.backend, - scheduling_spec=rollout_cfg.scheduling_spec, - setup_timeout=rollout_cfg.setup_timeout, - request_timeout=rollout_cfg.request_timeout, - openai=rollout_cfg.openai, - ) - # --- Scheduler --- from areal.infra.scheduler.local import LocalScheduler from areal.infra.scheduler.slurm import SlurmScheduler @@ -219,7 +198,7 @@ def main(argv: list[str]) -> None: else: raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") - ctrl = GatewayInferenceController(config=ctrl_config, scheduler=scheduler) + ctrl = RolloutControllerV2(config=rollout_cfg, scheduler=scheduler) ctrl.initialize( role="rollout", server_args=server_args, diff --git a/examples/experimental/inference_service/tau2_rollout.yaml b/examples/experimental/inference_service/tau2_rollout.yaml index 13b19a5904..8fd69c6191 100644 --- a/examples/experimental/inference_service/tau2_rollout.yaml +++ b/examples/experimental/inference_service/tau2_rollout.yaml @@ -29,6 +29,7 @@ gconfig: temperature: 1.0 rollout: + _version: v2 backend: "sglang:d2" experiment_name: ${experiment_name} trial_name: ${trial_name} @@ -40,6 +41,7 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: false + admin_api_key: rollout-admin openai: mode: inline tool_call_parser: qwen25 diff --git a/tests/experimental/agent_service/test_controller.py b/tests/experimental/agent_service/test_controller.py index 376bed71c6..bacbd1bbe9 100644 --- a/tests/experimental/agent_service/test_controller.py +++ b/tests/experimental/agent_service/test_controller.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -"""Unit tests for AgentServiceController. +"""Unit tests for AgentController. All Guard HTTP interactions are mocked — no real processes or servers. Tests cover: initialize, destroy, scale_up, scale_down, and error handling. @@ -14,12 +14,8 @@ import pytest -from areal.experimental.agent_service.controller.config import ( - AgentServiceControllerConfig, -) -from areal.experimental.agent_service.controller.controller import ( - AgentServiceController, -) +from areal.api.cli_args import AgentConfig +from areal.experimental.agent_service.controller.controller import AgentController CTRL = "areal.experimental.agent_service.controller.controller" @@ -83,7 +79,7 @@ def _mock_health_response(active_sessions: int = 0) -> MagicMock: @pytest.fixture() def config(): - return AgentServiceControllerConfig( + return AgentConfig( agent_cls_path="my.Agent", admin_api_key="test-key", num_pairs=2, @@ -116,7 +112,7 @@ def mock_post(url, **kwargs): class TestConstruction: def test_construction(self, config): scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) assert ctrl.router_addr == "" assert ctrl.gateway_addr == "" assert ctrl.pairs == {} @@ -129,7 +125,7 @@ def test_initialize_forks_router_pairs_gateway(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090"), ("10.0.0.2", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() scheduler.create_workers.assert_called_once() @@ -148,7 +144,7 @@ def test_scale_up_adds_pairs(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert len(ctrl.pairs) == 0 @@ -178,7 +174,7 @@ def mock_post(url, **kwargs): mock_requests.RequestException = Exception scheduler = _make_scheduler(("g0", "8090"), ("g1", "8091")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() guards_called.clear() @@ -197,7 +193,7 @@ def test_scale_down_removes_newest_first(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert len(ctrl.pairs) == 3 @@ -214,7 +210,7 @@ def test_destroy_clears_everything(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert len(ctrl._forked_services) > 0 @@ -246,7 +242,7 @@ def mock_post(url, **kwargs): mock_requests.RequestException = Exception scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() ctrl.destroy() @@ -274,7 +270,7 @@ def mock_get(url, **kwargs): mock_requests.get = mock_get scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() health_call_count = 0 @@ -301,7 +297,7 @@ def counting_get(url, **kwargs): mock_requests.get = counting_get scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() pre_get_count = get_count @@ -318,7 +314,7 @@ def test_health_monitor_starts_and_stops(self, mock_requests, config): _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert ctrl._health_thread is not None assert ctrl._health_thread.is_alive() @@ -333,7 +329,7 @@ def test_health_monitor_disabled_when_interval_zero(self, mock_requests, config) _setup_mock_requests(mock_requests) scheduler = _make_scheduler(("10.0.0.1", "8090")) - ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl = AgentController(config=config, scheduler=scheduler) ctrl.initialize() assert ctrl._health_thread is None diff --git a/tests/experimental/agent_service/test_guard.py b/tests/experimental/agent_service/test_guard.py index e5a82deac3..674808d618 100644 --- a/tests/experimental/agent_service/test_guard.py +++ b/tests/experimental/agent_service/test_guard.py @@ -2,7 +2,7 @@ Tests that the base guard routes are available on the agent guard app. The agent_blueprint has been removed in v2 — all orchestration logic -now lives in AgentServiceController. +now lives in AgentController. Test structure mirrors ``tests/experimental/inference_service/test_guard.py``. """ diff --git a/tests/experimental/inference_service/test_controller.py b/tests/experimental/inference_service/test_controller.py index 5e61f768f1..877635a34a 100644 --- a/tests/experimental/inference_service/test_controller.py +++ b/tests/experimental/inference_service/test_controller.py @@ -1,4 +1,4 @@ -"""Tests for GatewayInferenceController.""" +"""Tests for RolloutControllerV2.""" from __future__ import annotations @@ -8,48 +8,57 @@ import httpx import pytest -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) +from areal.api.cli_args import InferenceEngineConfig from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.inference_service.controller.workflow import ( InferenceServiceWorkflow, ) # ============================================================================= -# GatewayControllerConfig +# InferenceEngineConfig # ============================================================================= -class TestGatewayControllerConfig: +class TestInferenceEngineConfigForInferenceService: def test_defaults(self): - cfg = GatewayControllerConfig() - assert cfg.admin_api_key is None + cfg = InferenceEngineConfig(backend="sglang:d1") + assert cfg.admin_api_key == "areal-admin-key" assert cfg.model == "default" - assert cfg.consumer_batch_size == 16 + assert cfg.consumer_batch_size == 1 assert cfg.max_concurrent_rollouts is None assert cfg.max_head_offpolicyness == 0 assert cfg.enable_rollout_tracing is False assert cfg.set_reward_finish_timeout == 0.0 + assert cfg.session_timeout_seconds == 3600.0 + assert cfg.stale_session_cleanup_interval_seconds == 60.0 + assert cfg.stale_session_dump_path == "" def test_custom_values(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="custom-key", consumer_batch_size=32, max_concurrent_rollouts=64, max_head_offpolicyness=5, set_reward_finish_timeout=3.0, + session_timeout_seconds=120.0, + stale_session_cleanup_interval_seconds=15.0, + stale_session_dump_path="/tmp/stale-dumps", ) assert cfg.admin_api_key == "custom-key" assert cfg.consumer_batch_size == 32 assert cfg.max_concurrent_rollouts == 64 assert cfg.max_head_offpolicyness == 5 assert cfg.set_reward_finish_timeout == 3.0 + assert cfg.session_timeout_seconds == 120.0 + assert cfg.stale_session_cleanup_interval_seconds == 15.0 + assert cfg.stale_session_dump_path == "/tmp/stale-dumps" def test_scheduling_fields(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", request_timeout=60.0, setup_timeout=600.0, ) @@ -57,30 +66,31 @@ def test_scheduling_fields(self): assert cfg.setup_timeout == 600.0 def test_dump_to_file_defaults_to_false(self): - cfg = GatewayControllerConfig() + cfg = InferenceEngineConfig(backend="sglang:d1") assert cfg.dump_to_file is False # ============================================================================= -# GatewayInferenceController — workflow resolution helpers +# RolloutControllerV2 — workflow resolution helpers # ============================================================================= class TestControllerWorkflowResolution: def test_resolve_workflow_with_instance(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) with pytest.raises(TypeError, match=r"callable run\(\) method"): controller._resolve_workflow(12345) def test_resolve_workflow_none_creates_online_inference_service_workflow(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://test:8080" resolved = controller._resolve_workflow( @@ -94,11 +104,12 @@ def test_resolve_workflow_none_creates_online_inference_service_workflow(self): assert resolved.timeout == 3.0 def test_resolve_workflow_agent_class_creates_offline_workflow(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://test:8080" class MockAgent: @@ -115,17 +126,17 @@ async def run(self, data, **kwargs): assert isinstance(resolved.agent, MockAgent) def test_resolve_should_accept_fn_none(self): - assert GatewayInferenceController._resolve_should_accept_fn(None) is None + assert RolloutControllerV2._resolve_should_accept_fn(None) is None def test_resolve_should_accept_fn_callable(self): fn = lambda x: True # noqa: E731 - assert GatewayInferenceController._resolve_should_accept_fn(fn) is fn + assert RolloutControllerV2._resolve_should_accept_fn(fn) is fn def test_resolve_workflow_with_agent_class(self): """Test _resolve_workflow wraps agent-like classes in InferenceServiceWorkflow.""" - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://test:8080" class MockAgent: @@ -141,8 +152,8 @@ async def run(self, data, **kwargs): assert hasattr(resolved, "arun_episode") def test_resolve_workflow_agent_class_without_gateway_raises(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) @@ -154,8 +165,8 @@ async def run(self, data, **kwargs): controller._resolve_workflow(MockAgent, workflow_kwargs={}) def test_resolve_workflow_rollout_workflow_instance_raises(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) controller._gateway_addr = "http://test:8080" @@ -172,8 +183,8 @@ def test_resolve_workflow_rollout_workflow_instance_raises(self): controller._resolve_workflow(workflow) def test_resolve_workflow_rollout_workflow_class_raises(self): - controller = GatewayInferenceController( - config=GatewayControllerConfig(admin_api_key="test-key"), + controller = RolloutControllerV2( + config=InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key"), scheduler=MagicMock(), ) controller._gateway_addr = "http://test:8080" @@ -188,11 +199,11 @@ def test_resolve_workflow_rollout_workflow_class_raises(self): # ============================================================================= -# GatewayInferenceController — API surface +# RolloutControllerV2 — API surface # ============================================================================= -class TestGatewayInferenceControllerAPISurface: +class TestRolloutControllerV2APISurface: def test_has_all_public_methods(self): methods = [ "initialize", @@ -216,7 +227,7 @@ def test_has_all_public_methods(self): "start_proxy_gateway", ] for m in methods: - assert hasattr(GatewayInferenceController, m), f"Missing method: {m}" + assert hasattr(RolloutControllerV2, m), f"Missing method: {m}" def test_has_properties(self): properties = [ @@ -228,35 +239,38 @@ def test_has_properties(self): "worker_ids", ] for p in properties: - assert hasattr(GatewayInferenceController, p), f"Missing property: {p}" + assert hasattr(RolloutControllerV2, p), f"Missing property: {p}" def test_not_subclass_of_rollout_controller(self): - """GatewayInferenceController must NOT be a subclass of RolloutController.""" + """RolloutControllerV2 must NOT be a subclass of RolloutController.""" # Verify it doesn't inherit from any class except object - bases = GatewayInferenceController.__bases__ + bases = RolloutControllerV2.__bases__ assert bases == (object,), f"Unexpected bases: {bases}" # ============================================================================= -# GatewayInferenceController — construction + state +# RolloutControllerV2 — construction + state # ============================================================================= -class TestGatewayInferenceControllerConstruction: +class TestRolloutControllerV2Construction: def test_admin_api_key_none_raises(self): - cfg = GatewayControllerConfig() + cfg = InferenceEngineConfig(backend="sglang:d1") + cfg.admin_api_key = "" with pytest.raises(ValueError, match="admin_api_key must be set"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_model_empty_raises(self): - cfg = GatewayControllerConfig(admin_api_key="test-key", model="") + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-key", model="" + ) with pytest.raises(ValueError, match="model must not be empty"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_constructor(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) assert controller.config is cfg assert controller.scheduler is scheduler @@ -268,63 +282,63 @@ def test_constructor(self): assert controller.worker_ids == {} def test_admin_api_key_defaults(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) assert controller.config.admin_api_key == "test-key" def test_version_management_without_services(self): """set_version / get_version work even without gateway services.""" - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # No gateway services started, but version management is local controller._version = 42 assert controller.get_version() == 42 def test_export_stats_returns_dict(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) stats = controller.export_stats() assert isinstance(stats, dict) def test_start_proxy_is_noop(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # Should not raise controller.start_proxy() controller.start_proxy_gateway() def test_proxy_gateway_addr(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # Before initialize, proxy_gateway_addr returns the empty _gateway_addr assert controller.proxy_gateway_addr == "" def test_callback_addr_formats_ipv6_hostport(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "2001:db8::10" controller._callback_port = 19000 assert controller.callback_addr == "[2001:db8::10]:19000" def test_workflow_executor_raises_before_init(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) with pytest.raises(RuntimeError, match="initialize"): _ = controller.workflow_executor def test_config_perf_tracer_is_noop(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) # Should not raise controller.config_perf_tracer() controller.save_perf_tracer() @@ -343,7 +357,8 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy scheduler = MagicMock() scheduler.get_workers.return_value = [worker] - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", tokenizer_path="mock-tokenizer", request_timeout=15.0, set_reward_finish_timeout=7.5, @@ -357,7 +372,7 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy ), admin_api_key="test-admin-key", ) - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" controller._callback_port = 19000 @@ -383,17 +398,76 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy assert "--callback-server-addr" in data_proxy_cmd assert "http://127.0.0.1:19000" in data_proxy_cmd + @pytest.mark.asyncio + async def test_async_initialize_passes_stale_cleanup_settings_to_data_proxy( + self, + ): + from areal.api.cli_args import SchedulingSpec + from areal.api.io_struct import LocalInfServerInfo + + worker = MagicMock() + worker.ip = "127.0.0.1" + worker.worker_ports = [18000] + + scheduler = MagicMock() + scheduler.get_workers.return_value = [worker] + + cfg = InferenceEngineConfig( + backend="sglang:d1", + tokenizer_path="mock-tokenizer", + request_timeout=15.0, + session_timeout_seconds=321.0, + stale_session_cleanup_interval_seconds=22.5, + stale_session_dump_path="/tmp/stale-session-dumps", + scheduling_spec=( + SchedulingSpec( + gpu=0, + cpu=1, + mem=1, + cmd="python -m areal.experimental.inference_service.guard", + ), + ), + admin_api_key="test-admin-key", + ) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) + controller._callback_host = "127.0.0.1" + controller._callback_port = 19000 + + with patch.object(controller, "_fork_on_guard") as mock_fork: + mock_fork.side_effect = [ + ("127.0.0.1", 18081), + ("127.0.0.1", 18082), + ("127.0.0.1", 18080), + ] + + await controller._async_initialize( + server_args=None, + server_infos=[ + LocalInfServerInfo( + host="127.0.0.1", port=30000, process=MagicMock() + ) + ], + ) + + data_proxy_cmd = mock_fork.call_args_list[1].kwargs["raw_cmd"] + assert "--session-timeout-seconds" in data_proxy_cmd + assert "321.0" in data_proxy_cmd + assert "--stale-session-cleanup-interval-seconds" in data_proxy_cmd + assert "22.5" in data_proxy_cmd + assert "--stale-session-dump-path" in data_proxy_cmd + assert "/tmp/stale-session-dumps" in data_proxy_cmd + # ============================================================================= -# GatewayInferenceController — gateway HTTP helpers +# RolloutControllerV2 — gateway HTTP helpers # ============================================================================= -class TestGatewayInferenceControllerHTTP: +class TestRolloutControllerV2HTTP: def test_gateway_http_post_raises_on_failure(self): - cfg = GatewayControllerConfig(admin_api_key="test-key") + cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://127.0.0.1:19999" with pytest.raises(RuntimeError, match="Failed to POST"): controller._gateway_http_post("/test", {"key": "value"}) @@ -404,11 +478,12 @@ def test_gateway_http_post_sends_auth(self, mock_post): mock_resp.status_code = 200 mock_post.return_value = mock_resp - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="my-secret-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://127.0.0.1:8080" controller._gateway_http_post("/test_endpoint", {"data": 1}) @@ -422,11 +497,12 @@ def test_gateway_http_post_sends_auth(self, mock_post): class TestOnlineCallbackFlow: @pytest.mark.asyncio async def test_online_callback_without_waiter_buffers_export_request(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() try: async with httpx.AsyncClient() as client: @@ -443,11 +519,12 @@ async def test_online_callback_without_waiter_buffers_export_request(self): @pytest.mark.asyncio async def test_online_callback_settles_waiter_once(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() waiter_task = asyncio.create_task( @@ -470,11 +547,12 @@ async def test_online_callback_settles_waiter_once(self): @pytest.mark.asyncio async def test_online_callback_invalid_payload_keeps_waiter_pending(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() waiter_task = asyncio.create_task( @@ -497,11 +575,12 @@ async def test_online_callback_invalid_payload_keeps_waiter_pending(self): @pytest.mark.asyncio async def test_cancelled_waiter_buffers_completed_online_result(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-admin-key", ) scheduler = MagicMock() - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._start_online_callback_server() waiter_task = asyncio.create_task( @@ -644,38 +723,39 @@ async def run(self, data, **kwargs): class TestMultiNodeConfig: def test_n_gpus_per_node_default_is_none(self): - cfg = GatewayControllerConfig() + cfg = InferenceEngineConfig(backend="sglang:d1") assert cfg.n_gpus_per_node is None def test_n_gpus_per_node_custom(self): - cfg = GatewayControllerConfig(n_gpus_per_node=4) + cfg = InferenceEngineConfig(backend="sglang:d1", n_gpus_per_node=4) assert cfg.n_gpus_per_node == 4 def test_n_gpus_per_node_zero_raises(self): - cfg = GatewayControllerConfig( - n_gpus_per_node=0, backend="sglang:d1t8", admin_api_key="test-key" + cfg = InferenceEngineConfig( + n_gpus_per_node=1, backend="sglang:d1t8", admin_api_key="test-key" ) + cfg.n_gpus_per_node = 0 with pytest.raises(ValueError, match="n_gpus_per_node must be >= 1"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_gpus_not_divisible_raises(self): - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( n_gpus_per_node=3, backend="sglang:d1t8", admin_api_key="test-key" ) with pytest.raises(ValueError, match="must be divisible by n_gpus_per_node"): - GatewayInferenceController(config=cfg, scheduler=MagicMock()) + RolloutControllerV2(config=cfg, scheduler=MagicMock()) def test_single_node_backward_compat(self): - cfg = GatewayControllerConfig(backend="sglang:d2t4", admin_api_key="test-key") - controller = GatewayInferenceController(config=cfg, scheduler=MagicMock()) + cfg = InferenceEngineConfig(backend="sglang:d2t4", admin_api_key="test-key") + controller = RolloutControllerV2(config=cfg, scheduler=MagicMock()) assert controller._nnodes_per_instance == 1 def test_multi_node_valid_config(self): # tp=16, n_gpus_per_node=8 → nnodes_per_instance=2 - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( n_gpus_per_node=8, backend="sglang:d1t16", admin_api_key="test-key" ) - controller = GatewayInferenceController(config=cfg, scheduler=MagicMock()) + controller = RolloutControllerV2(config=cfg, scheduler=MagicMock()) assert controller._nnodes_per_instance == 2 @pytest.mark.asyncio @@ -698,14 +778,14 @@ async def test_async_initialize_multinode_worker_count(self): scheduler.get_workers.return_value = [worker0] # tp=8, n_gpus_per_node=4 → nnodes_per_instance=2 - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( tokenizer_path="mock-tokenizer", backend="sglang:d1t8", n_gpus_per_node=4, scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),), admin_api_key="test-key", ) - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" controller._callback_port = 19000 @@ -756,14 +836,14 @@ async def test_async_initialize_multinode_fork_path(self): scheduler.get_workers.return_value = [worker0, worker1] # tp=8, n_gpus_per_node=4 → nnodes_per_instance=2 - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( tokenizer_path="mock-tokenizer", backend="sglang:d1t8", n_gpus_per_node=4, scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),), admin_api_key="test-key", ) - controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller = RolloutControllerV2(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" controller._callback_port = 19000 diff --git a/tests/experimental/inference_service/test_controller_integration.py b/tests/experimental/inference_service/test_controller_integration.py index cdaa0b1868..380694ddb9 100644 --- a/tests/experimental/inference_service/test_controller_integration.py +++ b/tests/experimental/inference_service/test_controller_integration.py @@ -1,4 +1,4 @@ -"""Integration tests for GatewayInferenceController with real SGLang servers. +"""Integration tests for RolloutControllerV2 with real SGLang servers. Requires GPU and a model. Marked @pytest.mark.slow to exclude from default CI. Run manually: @@ -6,8 +6,8 @@ The test launches: 1. A real SGLang server (GPU subprocess) - 2. Module-scoped LocalScheduler / GatewayInferenceController fixtures - 3. A GatewayInferenceController that spins up Gateway, Router, and Data Proxy +2. Module-scoped LocalScheduler / RolloutControllerV2 fixtures +3. A RolloutControllerV2 that spins up Gateway, Router, and Data Proxy micro-services in background threads. """ @@ -121,14 +121,12 @@ def _make_gateway_controller_config( online_mode: bool = False, set_reward_finish_timeout: float = 0.0, ): - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec - return GatewayControllerConfig( + return InferenceEngineConfig( + backend="sglang:d1", tokenizer_path=model_path, - model_path=model_path, + model=model_path, set_reward_finish_timeout=set_reward_finish_timeout, scheduling_spec=( SchedulingSpec( @@ -189,17 +187,17 @@ def _export_trajectory_with_retry( @pytest.fixture(scope="module") def gateway_controller(sglang_server, model_path, tmp_path_factory): - """Create and initialize a GatewayInferenceController, yield it, then destroy.""" + """Create and initialize a RolloutControllerV2, yield it, then destroy.""" if not has_gpu(): pytest.skip("GPU required") from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) local_scheduler = _make_local_scheduler(tmp_path_factory, "gateway_controller") config = _make_gateway_controller_config(model_path) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout", @@ -219,14 +217,14 @@ def gateway_controller_online(sglang_server, model_path, tmp_path_factory): pytest.skip("GPU required") from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_online" ) config = _make_gateway_controller_config(model_path, online_mode=True) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout", @@ -246,7 +244,7 @@ def gateway_controller_with_reward_timeout(sglang_server, model_path, tmp_path_f pytest.skip("GPU required") from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) local_scheduler = _make_local_scheduler( @@ -256,7 +254,7 @@ def gateway_controller_with_reward_timeout(sglang_server, model_path, tmp_path_f model_path, set_reward_finish_timeout=3.0, ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-timeout", @@ -777,7 +775,7 @@ def test_offline_export_applies_discount_after_multiple_rewards_in_same_trajecto @pytest.fixture(scope="module") def gateway_controller_full_init(model_path, tmp_path_factory): - """Create a GatewayInferenceController that launches SGLang via the full init path. + """Create a RolloutControllerV2 that launches SGLang via the full init path. Unlike ``gateway_controller`` which passes pre-existing ``server_infos``, this fixture lets the controller create RPC workers, create @@ -786,17 +784,14 @@ def gateway_controller_full_init(model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) - config = GatewayControllerConfig( + config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, + model=model_path, backend="sglang:d1", scheduling_spec=( SchedulingSpec( @@ -810,6 +805,7 @@ def gateway_controller_full_init(model_path, tmp_path_factory): ) server_args = { + "model_path": model_path, "skip_tokenizer_init": True, "mem_fraction_static": 0.15, } @@ -817,7 +813,7 @@ def gateway_controller_full_init(model_path, tmp_path_factory): local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_full_init" ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-full", server_args=server_args, @@ -1075,17 +1071,14 @@ def gateway_controller_full_init_vllm(model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) - config = GatewayControllerConfig( + config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, + model=model_path, backend="vllm:d1", scheduling_spec=( SchedulingSpec( @@ -1099,13 +1092,14 @@ def gateway_controller_full_init_vllm(model_path, tmp_path_factory): ) server_args = { + "model": model_path, "gpu_memory_utilization": 0.15, } local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_full_init_vllm" ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-vllm", server_args=server_args, @@ -1280,17 +1274,14 @@ def gateway_controller_full_init_vlm_sglang(vlm_model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) - config = GatewayControllerConfig( + config = InferenceEngineConfig( tokenizer_path=vlm_model_path, - model_path=vlm_model_path, + model=vlm_model_path, backend="sglang:d1", scheduling_spec=( SchedulingSpec( @@ -1306,10 +1297,14 @@ def gateway_controller_full_init_vlm_sglang(vlm_model_path, tmp_path_factory): local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_full_init_vlm_sglang" ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-vlm-sglang", - server_args={"skip_tokenizer_init": True, "mem_fraction_static": 0.25}, + server_args={ + "model_path": vlm_model_path, + "skip_tokenizer_init": True, + "mem_fraction_static": 0.25, + }, ) try: @@ -1324,17 +1319,14 @@ def gateway_controller_full_init_vlm_vllm(vlm_model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import SchedulingSpec - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) + from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) - config = GatewayControllerConfig( + config = InferenceEngineConfig( tokenizer_path=vlm_model_path, - model_path=vlm_model_path, + model=vlm_model_path, backend="vllm:d1", scheduling_spec=( SchedulingSpec( @@ -1350,7 +1342,7 @@ def gateway_controller_full_init_vlm_vllm(vlm_model_path, tmp_path_factory): local_scheduler = _make_local_scheduler( tmp_path_factory, "gateway_controller_full_init_vlm_vllm" ) - ctrl = GatewayInferenceController(config=config, scheduler=local_scheduler) + ctrl = RolloutControllerV2(config=config, scheduler=local_scheduler) ctrl.initialize( role="rollout-vlm-vllm", server_args={"gpu_memory_utilization": 0.25}, diff --git a/tests/experimental/inference_service/test_controller_version.py b/tests/experimental/inference_service/test_controller_version.py index b2d90c3a5d..255bc6319a 100644 --- a/tests/experimental/inference_service/test_controller_version.py +++ b/tests/experimental/inference_service/test_controller_version.py @@ -1,4 +1,4 @@ -"""Unit tests for GatewayInferenceController version management. +"""Unit tests for RolloutControllerV2 version management. Tests set_version and get_version with mocked HTTP calls. """ @@ -8,12 +8,9 @@ import asyncio from unittest.mock import AsyncMock, MagicMock, patch -from areal.api.cli_args import SchedulingSpec -from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, -) +from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) # ============================================================================= @@ -25,17 +22,18 @@ def _make_controller( gateway_addr: str = "", worker_ids: dict[str, str] | None = None, version: int = 0, -) -> GatewayInferenceController: +) -> RolloutControllerV2: """Create a controller with minimal config and manually injected state. Does NOT call initialize() — internal fields are set directly. """ - cfg = GatewayControllerConfig( + cfg = InferenceEngineConfig( + backend="sglang:d1", admin_api_key="test-key", scheduling_spec=(SchedulingSpec(),), ) scheduler = MagicMock() - ctrl = GatewayInferenceController(config=cfg, scheduler=scheduler) + ctrl = RolloutControllerV2(config=cfg, scheduler=scheduler) ctrl._gateway_addr = gateway_addr ctrl._worker_ids = worker_ids or {} ctrl._version = version @@ -48,7 +46,7 @@ def _make_controller( class TestControllerSetVersion: - """Test GatewayInferenceController.set_version.""" + """Test RolloutControllerV2.set_version.""" def test_set_version_updates_local(self): ctrl = _make_controller() @@ -93,7 +91,7 @@ def test_set_version_broadcasts_to_all_workers(self): class TestControllerGetVersion: - """Test GatewayInferenceController.get_version.""" + """Test RolloutControllerV2.get_version.""" def test_get_version_returns_local(self): ctrl = _make_controller(version=0) diff --git a/tests/experimental/inference_service/test_data_proxy_chat.py b/tests/experimental/inference_service/test_data_proxy_chat.py index c53abfba91..733ad9a411 100644 --- a/tests/experimental/inference_service/test_data_proxy_chat.py +++ b/tests/experimental/inference_service/test_data_proxy_chat.py @@ -2,12 +2,17 @@ from __future__ import annotations +import asyncio +import json import time +from pathlib import Path from unittest.mock import AsyncMock, MagicMock import httpx +import orjson import pytest import pytest_asyncio +import torch from areal.experimental.inference_service.data_proxy.app import ( _flush_ready_trajectories, @@ -17,7 +22,12 @@ from areal.experimental.inference_service.data_proxy.session import ( SessionData, SessionStore, + _dump_stale_session_exports, ) +from areal.experimental.openai.proxy.server import deserialize_interactions +from areal.experimental.openai.types import InteractionWithTokenLogpReward +from areal.infra.rpc import rtensor as rtensor_storage +from areal.infra.rpc.rtensor import RTensor # ============================================================================= # Fixtures @@ -118,6 +128,15 @@ async def _mock_create(*, areal_cache=None, **kwargs): return mock_client +@pytest.fixture(autouse=True) +def clear_rtensor_storage(): + rtensor_storage._storage.clear() + rtensor_storage._storage_stats.clear() + yield + rtensor_storage._storage.clear() + rtensor_storage._storage_stats.clear() + + @pytest_asyncio.fixture async def client(config, mock_tokenizer, mock_areal_client): """Create app with mocked deps and yield an httpx async client.""" @@ -1197,3 +1216,294 @@ async def test_full_session_lifecycle(client, mock_areal_client): assert resp.status_code == 200 data = resp.json() assert "interactions" in data + + +@pytest.mark.asyncio +async def test_cleanup_stale_sessions_dumps_and_removes_ready_session(client, tmp_path): + app = client._transport.app + app.state.config.stale_session_dump_path = str(tmp_path) + app.state.config.session_timeout_seconds = 1.0 + + start = await client.post( + "/rl/start_session", + json={"task_id": "stale-ready"}, + headers=admin_headers(), + ) + session_api_key = start.json()["api_key"] + session_id = start.json()["session_id"] + + await client.post( + "/chat/completions", + json={"model": "sglang", "messages": [{"role": "user", "content": "hi"}]}, + headers=session_headers(session_api_key), + ) + reward_resp = await client.post( + "/rl/set_reward", + json={"reward": 1.0}, + headers=session_headers(session_api_key), + ) + assert reward_resp.status_code == 200 + + session = app.state.session_store.get_session(session_id) + assert session is not None + session._last_access_time -= 10.0 + + await app.state.session_store.cleanup_stale( + timeout_seconds=app.state.config.session_timeout_seconds, + stale_session_dump_path=app.state.config.stale_session_dump_path, + serving_addr=app.state.config.serving_addr, + ) + + dump_file = Path(tmp_path) / f"{session_id}-trajectory-0.json" + assert dump_file.exists() + dumped = orjson.loads(dump_file.read_bytes()) + assert list(dumped) == ["chatcmpl-test0"] + restored = deserialize_interactions(dumped) + assert isinstance(restored["chatcmpl-test0"]._cache["input_ids"], torch.Tensor) + assert app.state.session_store.get_session(session_id) is None + + +@pytest.mark.asyncio +async def test_cleanup_stale_sessions_dumps_active_interactions_and_clears_shards( + client, tmp_path +): + app = client._transport.app + app.state.config.stale_session_dump_path = str(tmp_path) + app.state.config.session_timeout_seconds = 1.0 + + start = await client.post( + "/rl/start_session", + json={"task_id": "stale-active"}, + headers=admin_headers(), + ) + session_api_key = start.json()["api_key"] + session_id = start.json()["session_id"] + + await client.post( + "/chat/completions", + json={"model": "sglang", "messages": [{"role": "user", "content": "hi"}]}, + headers=session_headers(session_api_key), + ) + + session = app.state.session_store.get_session(session_id) + assert session is not None + session._last_access_time -= 10.0 + + await app.state.session_store.cleanup_stale( + timeout_seconds=app.state.config.session_timeout_seconds, + stale_session_dump_path=app.state.config.stale_session_dump_path, + serving_addr=app.state.config.serving_addr, + ) + + dump_file = Path(tmp_path) / f"{session_id}-active.json" + assert dump_file.exists() + dumped = orjson.loads(dump_file.read_bytes()) + assert list(dumped) == ["chatcmpl-test0"] + restored = deserialize_interactions(dumped) + input_ids = restored["chatcmpl-test0"]._cache["input_ids"] + assert isinstance(input_ids, torch.Tensor) + assert not input_ids.is_meta + assert app.state.session_store.get_session(session_id) is None + assert rtensor_storage._storage == {} + + +@pytest.mark.asyncio +async def test_cleanup_stale_sessions_localizes_existing_rtensors_before_dump( + client, tmp_path +): + app = client._transport.app + app.state.config.stale_session_dump_path = str(tmp_path) + app.state.config.session_timeout_seconds = 1.0 + app.state.config.serving_addr = "" + + session_id, _ = app.state.session_store.start_session("stale-rtensor") + session = app.state.session_store.get_session(session_id) + assert session is not None + + interaction = InteractionWithTokenLogpReward( + messages=[{"role": "user", "content": "hi"}], + output_message_list=[{"role": "assistant", "content": "hello"}], + ) + interaction.interaction_id = "rtensor-interaction" + interaction._cache = RTensor.remotize( + { + "input_ids": torch.tensor([[1, 2, 3]]), + "attention_mask": torch.tensor([[True, True, True]]), + "loss_mask": torch.tensor([[0, 1, 1]]), + "logprobs": torch.tensor([[0.0, -0.5, -0.1]]), + "versions": torch.tensor([[0, 0, 0]]), + "rewards": torch.tensor([1.0]), + }, + node_addr="", + ) + session.active_completions[interaction.interaction_id] = interaction + session._last_access_time -= 10.0 + + await app.state.session_store.cleanup_stale( + timeout_seconds=app.state.config.session_timeout_seconds, + stale_session_dump_path=app.state.config.stale_session_dump_path, + serving_addr=app.state.config.serving_addr, + ) + + dump_file = Path(tmp_path) / f"{session_id}-active.json" + assert dump_file.exists() + restored = deserialize_interactions(orjson.loads(dump_file.read_bytes())) + cache = restored[interaction.interaction_id]._cache + assert isinstance(cache["input_ids"], torch.Tensor) + assert not cache["input_ids"].is_meta + assert rtensor_storage._storage == {} + + +@pytest.mark.asyncio +async def test_session_store_cleanup_stale_dumps_and_removes_session(tmp_path): + store = SessionStore() + session_id, _ = store.start_session("store-stale") + session = store.get_session(session_id) + assert session is not None + + interaction = InteractionWithTokenLogpReward( + messages=[{"role": "user", "content": "hi"}], + output_message_list=[{"role": "assistant", "content": "hello"}], + ) + interaction.interaction_id = "store-interaction" + interaction._cache = { + "input_ids": torch.tensor([[1, 2, 3]]), + "attention_mask": torch.tensor([[True, True, True]]), + "loss_mask": torch.tensor([[0, 1, 1]]), + "logprobs": torch.tensor([[0.0, -0.5, -0.1]]), + "versions": torch.tensor([[0, 0, 0]]), + "rewards": torch.tensor([1.0]), + } + session.active_completions[interaction.interaction_id] = interaction + session._last_access_time -= 10.0 + + await store.cleanup_stale( + timeout_seconds=1.0, + stale_session_dump_path=str(tmp_path), + serving_addr="", + ) + + dump_file = Path(tmp_path) / f"{session_id}-active.json" + assert dump_file.exists() + restored = deserialize_interactions(orjson.loads(dump_file.read_bytes())) + assert isinstance( + restored[interaction.interaction_id]._cache["input_ids"], torch.Tensor + ) + assert store.get_session(session_id) is None + + +@pytest.mark.asyncio +async def test_cleanup_stale_removes_session_before_async_dump(monkeypatch, tmp_path): + store = SessionStore() + session_id, api_key = store.start_session("store-stale-race") + session = store.get_session(session_id) + assert session is not None + + interaction = InteractionWithTokenLogpReward( + messages=[{"role": "user", "content": "hi"}], + output_message_list=[{"role": "assistant", "content": "hello"}], + ) + interaction.interaction_id = "race-interaction" + interaction._cache = { + "input_ids": torch.tensor([[1, 2, 3]]), + "attention_mask": torch.tensor([[True, True, True]]), + "loss_mask": torch.tensor([[0, 1, 1]]), + "logprobs": torch.tensor([[0.0, -0.5, -0.1]]), + "versions": torch.tensor([[0, 0, 0]]), + "rewards": torch.tensor([1.0]), + } + session.active_completions[interaction.interaction_id] = interaction + session._last_access_time -= 10.0 + + started = asyncio.Event() + continue_dump = asyncio.Event() + + async def _blocking_dump(*args, **kwargs): + started.set() + await continue_dump.wait() + + monkeypatch.setattr( + "areal.experimental.inference_service.data_proxy.session._dump_stale_session_exports", + _blocking_dump, + ) + + cleanup_task = asyncio.create_task( + store.cleanup_stale( + timeout_seconds=1.0, + stale_session_dump_path=str(tmp_path), + serving_addr="", + ) + ) + + await started.wait() + + assert store.get_session(session_id) is None + assert store.get_session_by_api_key(api_key) is None + + continue_dump.set() + await cleanup_task + + +def test_snapshot_cleanup_exports_holds_session_lock_during_export(monkeypatch): + session = SessionData("snapshot-lock") + + interaction = InteractionWithTokenLogpReward( + messages=[{"role": "user", "content": "hi"}], + output_message_list=[{"role": "assistant", "content": "hello"}], + ) + interaction.interaction_id = "snapshot-interaction" + session.active_completions[interaction.interaction_id] = interaction + + export_lock_states: list[bool] = [] + original_export = session.active_completions.export_interactions + + def _recording_export(*args, **kwargs): + export_lock_states.append(session._lock.locked()) + return original_export(*args, **kwargs) + + monkeypatch.setattr( + session.active_completions, + "export_interactions", + _recording_export, + ) + + exports = session.snapshot_cleanup_exports(discount=1.0, style="individual") + + assert exports + assert export_lock_states == [True] + + +@pytest.mark.asyncio +async def test_dump_stale_session_exports_offloads_json_dumps(monkeypatch, tmp_path): + interaction = InteractionWithTokenLogpReward( + messages=[{"role": "user", "content": "hi"}], + output_message_list=[{"role": "assistant", "content": "hello"}], + ) + interaction.interaction_id = "dump-interaction" + interaction._cache = { + "input_ids": torch.tensor([[1, 2, 3]]), + "attention_mask": torch.tensor([[True, True, True]]), + "loss_mask": torch.tensor([[0, 1, 1]]), + "logprobs": torch.tensor([[0.0, -0.5, -0.1]]), + "versions": torch.tensor([[0, 0, 0]]), + "rewards": torch.tensor([1.0]), + } + + recorded: list[object] = [] + original_to_thread = asyncio.to_thread + + async def _recording_to_thread(func, /, *args, **kwargs): + recorded.append(func) + return await original_to_thread(func, *args, **kwargs) + + monkeypatch.setattr(asyncio, "to_thread", _recording_to_thread) + + await _dump_stale_session_exports( + "dump-session", + [(None, {interaction.interaction_id: interaction})], + stale_session_dump_path=str(tmp_path), + serving_addr="", + ) + + assert recorded == [json.dumps] + assert (Path(tmp_path) / "dump-session-active.json").exists() diff --git a/tests/experimental/inference_service/test_examples.py b/tests/experimental/inference_service/test_examples.py index a91c552e56..763e4e5ee7 100644 --- a/tests/experimental/inference_service/test_examples.py +++ b/tests/experimental/inference_service/test_examples.py @@ -79,7 +79,8 @@ def wait_for_pattern(process, pattern, timeout, raise_on_exit=True): f"cluster.name_resolve.nfs_record_root={str(name_resolve_path)}", f"actor.path={model_path}", "scheduler.type=local", - f"rollout.openai.admin_api_key={admin_api_key}", + f"rollout.admin_api_key={admin_api_key}", + "rollout._version=v2", "stats_logger.wandb.mode=disabled", ] @@ -245,6 +246,7 @@ def test_tau2_rollout(tmp_path_factory): "cluster.n_gpus_per_node=2", f"cluster.fileroot={str(experiments_path)}", f"cluster.name_resolve.nfs_record_root={str(name_resolve_path)}", + "rollout.admin_api_key=test-admin-key", f"model_path={model_path}", "train_dataset.batch_size=2", "train_dataset.path=tau2/train", diff --git a/tests/experimental/weight_update/test_disk_integration.py b/tests/experimental/weight_update/test_disk_integration.py index fc4fde69f5..6824ba2894 100644 --- a/tests/experimental/weight_update/test_disk_integration.py +++ b/tests/experimental/weight_update/test_disk_integration.py @@ -360,15 +360,13 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): from areal.api import FinetuneSpec from areal.api.cli_args import ( + InferenceEngineConfig, OptimizerConfig, SchedulingSpec, TrainEngineConfig, ) - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.training_service.controller.controller import ( GatewayTrainController, @@ -385,9 +383,8 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): scheduler = _make_local_scheduler(tmp, "disk_e2e", gpu_devices=list(range(n_gpus))) - inf_config = GatewayControllerConfig( + inf_config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, backend=f"sglang:d{n_half}", scheduling_spec=( SchedulingSpec( @@ -400,7 +397,7 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): setup_timeout=300.0, admin_api_key="test-admin", ) - inf_ctrl = GatewayInferenceController(config=inf_config, scheduler=scheduler) + inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler) train_config = TrainEngineConfig( backend=f"fsdp:d{n_half}", @@ -430,7 +427,7 @@ def test_disk_e2e_weight_update(n_gpus, tmp_path_factory): # -- 1. SGLang via inference controller ---------------------------- inf_ctrl.initialize( role="rollout", - server_args={"mem_fraction_static": 0.7}, + server_args={"model_path": model_path, "mem_fraction_static": 0.7}, ) inf_worker_urls = list(inf_ctrl._inf_addrs) diff --git a/tests/experimental/weight_update/test_nccl_integration.py b/tests/experimental/weight_update/test_nccl_integration.py index cd6f443a0c..7ceca68423 100644 --- a/tests/experimental/weight_update/test_nccl_integration.py +++ b/tests/experimental/weight_update/test_nccl_integration.py @@ -308,15 +308,13 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): from areal.api import FinetuneSpec from areal.api.cli_args import ( + InferenceEngineConfig, OptimizerConfig, SchedulingSpec, TrainEngineConfig, ) - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, - ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.training_service.controller.controller import ( GatewayTrainController, @@ -332,9 +330,8 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): scheduler = _make_local_scheduler(tmp, "e2e", gpu_devices=list(range(n_gpus))) - inf_config = GatewayControllerConfig( + inf_config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, backend=f"sglang:d{n_half}", scheduling_spec=( SchedulingSpec( @@ -347,7 +344,7 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): setup_timeout=300.0, admin_api_key="test-admin", ) - inf_ctrl = GatewayInferenceController(config=inf_config, scheduler=scheduler) + inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler) train_config = TrainEngineConfig( backend=f"fsdp:d{n_half}", @@ -377,7 +374,7 @@ def test_awex_fsdp_e2e_weight_update(n_gpus, tmp_path_factory): # -- 1. SGLang via inference controller ---------------------------- inf_ctrl.initialize( role="rollout", - server_args={"mem_fraction_static": 0.7}, + server_args={"model_path": model_path, "mem_fraction_static": 0.7}, ) inf_worker_urls = list(inf_ctrl._inf_addrs) @@ -510,12 +507,14 @@ def _run_megatron_awex_e2e( model_path: str | None = None, ): from areal.api import FinetuneSpec - from areal.api.cli_args import OptimizerConfig, SchedulingSpec, TrainEngineConfig - from areal.experimental.inference_service.controller.config import ( - GatewayControllerConfig, + from areal.api.cli_args import ( + InferenceEngineConfig, + OptimizerConfig, + SchedulingSpec, + TrainEngineConfig, ) from areal.experimental.inference_service.controller.controller import ( - GatewayInferenceController, + RolloutControllerV2, ) from areal.experimental.training_service.controller.controller import ( GatewayTrainController, @@ -530,9 +529,8 @@ def _run_megatron_awex_e2e( model_path = model_path or _get_test_model_path() scheduler = _make_local_scheduler(tmp, tag, gpu_devices=list(range(n_gpus))) - inf_config = GatewayControllerConfig( + inf_config = InferenceEngineConfig( tokenizer_path=model_path, - model_path=model_path, backend=f"sglang:d{n_infer}", scheduling_spec=( SchedulingSpec( @@ -545,7 +543,7 @@ def _run_megatron_awex_e2e( setup_timeout=300.0, admin_api_key="test-admin", ) - inf_ctrl = GatewayInferenceController(config=inf_config, scheduler=scheduler) + inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler) train_config = TrainEngineConfig( backend=backend, @@ -571,7 +569,10 @@ def _run_megatron_awex_e2e( wu_ctrl: WeightUpdateController | None = None try: - inf_ctrl.initialize(role="rollout", server_args={"mem_fraction_static": 0.7}) + inf_ctrl.initialize( + role="rollout", + server_args={"model_path": model_path, "mem_fraction_static": 0.7}, + ) inf_worker_urls = list(inf_ctrl._inf_addrs) for url in inf_worker_urls: