Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1963,6 +1963,68 @@ def __post_init__(self):
raise ValueError("admin_api_key must not be empty or whitespace-only")


@dataclass
class AgentConfig:
    """Configuration for the experimental agent service controller.

    All fields are validated in ``__post_init__``; invalid values raise
    ``ValueError`` at construction time so misconfiguration fails fast.
    """

    # Required: the empty default deliberately fails validation, forcing
    # callers to supply an implementation path explicitly.
    agent_cls_path: str = field(
        default="",
        metadata={
            "help": "Fully-qualified import path for the AgentRunnable implementation."
        },
    )
    admin_api_key: str = field(
        default="areal-agent-admin",
        metadata={"help": "Shared admin API key for agent-service inter-service auth."},
    )
    num_pairs: int = field(
        default=1,
        metadata={"help": "Number of Worker+DataProxy pairs to launch on initialize."},
    )
    setup_timeout: float = field(
        default=120.0,
        metadata={
            "help": "Timeout in seconds waiting for each service to become healthy."
        },
    )
    health_poll_interval: float = field(
        default=5.0,
        metadata={
            "help": "Seconds between pair health polls; 0 disables health monitoring."
        },
    )
    drain_timeout: float = field(
        default=30.0,
        metadata={
            "help": "Seconds to wait for active sessions to drain before force-killing a pair."
        },
    )
    log_level: str = field(
        default="info",
        metadata={"help": "Log level for spawned agent-service micro-services."},
    )
    env: dict[str, str] = field(
        default_factory=dict,
        metadata={
            "help": "Extra environment variables passed to all forked child processes."
        },
    )

    def __post_init__(self) -> None:
        """Validate configuration values.

        Raises:
            ValueError: if any field holds an out-of-range or empty value.
        """
        if not self.agent_cls_path:
            raise ValueError("agent_cls_path must be a non-empty import path")
        # Consistent with the other service configs in this module, which
        # reject empty/whitespace-only admin keys.
        if not self.admin_api_key or not self.admin_api_key.strip():
            raise ValueError("admin_api_key must not be empty or whitespace-only")
        if self.num_pairs < 0:
            raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}")
        if self.setup_timeout <= 0:
            raise ValueError(
                f"setup_timeout must be positive, got {self.setup_timeout}"
            )
        # 0 means "health monitoring disabled" (see field help); negative
        # values have no meaning and were previously accepted silently.
        if self.health_poll_interval < 0:
            raise ValueError(
                f"health_poll_interval must be non-negative, "
                f"got {self.health_poll_interval}"
            )
        if self.drain_timeout < 0:
            raise ValueError(
                f"drain_timeout must be non-negative, got {self.drain_timeout}"
            )


@dataclass
class InferenceEngineConfig:
"""Configuration for inference servers, including offpolicyness control."""
Expand Down Expand Up @@ -2081,13 +2143,115 @@ class InferenceEngineConfig:
},
)

# v2 controller options
_version: str = field(
default="v1",
metadata={
"help": "Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2.",
"choices": ["v1", "v2"],
},
)
model: str = field(
default="default",
metadata={"help": "Model name exposed through the inference-service gateway."},
)
routing_strategy: str = field(
default="round_robin",
metadata={"help": "Routing strategy for the inference-service router."},
)
poll_interval: float = field(
default=5.0,
metadata={
"help": "Health-poll interval in seconds for the inference-service router."
},
)
set_reward_finish_timeout: float = field(
default=0.0,
metadata={
"help": "Timeout in seconds to wait for additional reward updates before finalizing a session."
},
)
session_timeout_seconds: float = field(
default=3600.0,
metadata={
"help": "Timeout in seconds before an inactive inference-service session is considered stale and cleaned up."
},
)
stale_session_cleanup_interval_seconds: float = field(
default=60.0,
metadata={
"help": "Polling interval in seconds for stale-session cleanup in inference-service data proxies."
},
)
stale_session_dump_path: str = field(
default="",
metadata={
"help": "Optional directory path where stale-session trajectory dumps are written before cleanup."
},
)
log_level: str = field(
default="info",
metadata={"help": "Log level for inference-service micro-services."},
)
admin_api_key: str = field(
default="areal-admin-key",
metadata={
"help": "Admin API key used by the inference-service gateway, router, and data proxies."
},
)
api_url: str | None = field(
default=None,
metadata={
"help": "External OpenAI-compatible base URL for inference-service external model mode."
},
)
provider_api_key: str | None = field(
default=None,
metadata={"help": "API key for the external OpenAI-compatible provider."},
)
n_gpus_per_node: int | None = field(
default=None,
metadata={
"help": "GPUs per physical node for multinode inference-service launch."
},
)

def __post_init__(self):
    """Validate inference-engine configuration fields.

    Checks scheduling_spec length, v2-controller options (_version,
    n_gpus_per_node, session timeouts, admin key), and warns when a
    v1-style OpenAI admin key is set under the v2 controller.

    Raises:
        ValueError: on any out-of-range or malformed field value.
    """
    # scheduling_spec must describe one co-located or two disaggregated specs.
    if len(self.scheduling_spec) not in (1, 2):
        raise ValueError(
            f"scheduling_spec must contain 1 or 2 SchedulingSpec, "
            f"got {len(self.scheduling_spec)}"
        )
    if self._version not in ("v1", "v2"):
        raise ValueError(
            f"_version must be either 'v1' or 'v2', got '{self._version}'"
        )
    # None means "not using multinode launch"; when set it must be >= 1.
    if self.n_gpus_per_node is not None and self.n_gpus_per_node < 1:
        raise ValueError(
            f"n_gpus_per_node must be >= 1, got {self.n_gpus_per_node}"
        )
    if self.session_timeout_seconds <= 0:
        raise ValueError(
            "session_timeout_seconds must be positive, "
            f"got {self.session_timeout_seconds}"
        )
    if self.stale_session_cleanup_interval_seconds <= 0:
        raise ValueError(
            "stale_session_cleanup_interval_seconds must be positive, "
            f"got {self.stale_session_cleanup_interval_seconds}"
        )
    if not self.admin_api_key or not self.admin_api_key.strip():
        raise ValueError("admin_api_key must not be empty or whitespace-only")
    # v2 reads rollout.admin_api_key, not rollout.openai.admin_api_key; warn
    # (rather than fail) when the openai key differs from its documented
    # default, since the stale setting would otherwise be silently ignored.
    # NOTE(review): "areal-admin-key" is assumed to be the openai sub-config
    # default as well — confirm against the OpenAI config definition.
    if (
        self._version == "v2"
        and self.openai is not None
        and self.openai.admin_api_key != "areal-admin-key"
    ):
        logger.warning(
            "rollout.openai.admin_api_key is ignored by rollout controller v2; "
            "use rollout.admin_api_key instead."
        )


@dataclass
Expand Down
5 changes: 2 additions & 3 deletions areal/experimental/agent_service/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,8 @@ areal/experimental/agent_service/
├── protocol.py # Gateway protocol frame types
├── types.py # AgentRequest, AgentResponse, EventEmitter, AgentRunnable
├── controller/
│ ├── __init__.py # AgentServiceController, AgentServiceControllerConfig
│ ├── config.py # AgentServiceControllerConfig dataclass
│ └── controller.py # AgentServiceController orchestrator
│ ├── __init__.py # AgentController export
│ └── controller.py # AgentController orchestrator
├── guard/
│ ├── __init__.py # Module docstring
│ ├── __main__.py # python -m areal.experimental.agent_service.guard
Expand Down
2 changes: 1 addition & 1 deletion areal/experimental/agent_service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

Submodules
----------
- ``controller`` — :class:`AgentServiceController` orchestrator
- ``controller`` — :class:`AgentController` orchestrator
- ``gateway`` — public HTTP/WebSocket entry point
- ``router`` — session-affine routing
- ``data_proxy`` — stateful session proxy
Expand Down
9 changes: 5 additions & 4 deletions areal/experimental/agent_service/controller/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

"""Agent Service Controller — orchestrator for agent micro-services."""

from .config import AgentServiceControllerConfig
from .controller import AgentServiceController
from areal.api.cli_args import AgentConfig

from .controller import AgentController

__all__ = [
"AgentServiceController",
"AgentServiceControllerConfig",
"AgentController",
"AgentConfig",
]
63 changes: 0 additions & 63 deletions areal/experimental/agent_service/controller/config.py

This file was deleted.

16 changes: 7 additions & 9 deletions areal/experimental/agent_service/controller/controller.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

"""AgentServiceController — orchestrates agent service micro-services via Guards.
"""AgentController — orchestrates agent service micro-services via Guards.

Mirrors the architecture of
:class:`~areal.experimental.inference_service.controller.controller.GatewayInferenceController`:
:class:`~areal.experimental.inference_service.controller.controller.RolloutControllerV2`:
Guard workers are created via the Scheduler, then the controller forks
Router, Worker+DataProxy pairs, and Gateway onto them via HTTP API.

Expand All @@ -12,7 +12,7 @@
from areal.infra.scheduler.local import LocalScheduler

scheduler = LocalScheduler(...)
controller = AgentServiceController(config, scheduler)
controller = AgentController(config, scheduler)
controller.initialize()
# ... run traffic ...
controller.scale_up(2) # add 2 Worker+DataProxy pairs
Expand All @@ -32,16 +32,14 @@

import requests

from areal.experimental.agent_service.controller.config import (
AgentServiceControllerConfig,
)
from areal.api.cli_args import AgentConfig
from areal.utils import logging
from areal.utils.network import format_hostport

if TYPE_CHECKING:
from areal.api.scheduler_api import Scheduler, Worker

logger = logging.getLogger("AgentServiceController")
logger = logging.getLogger("AgentController")

_GUARD_ROLE = "agent-guard"
_UNREGISTER_RETRIES = 3
Expand All @@ -60,7 +58,7 @@ class _WorkerPair:
worker_addr: str


class AgentServiceController:
class AgentController:
"""Orchestrator for the Agent Service micro-service stack.

Parameters
Expand All @@ -73,7 +71,7 @@ class AgentServiceController:

def __init__(
self,
config: AgentServiceControllerConfig,
config: AgentConfig,
scheduler: Scheduler,
) -> None:
self.config = config
Expand Down
8 changes: 3 additions & 5 deletions areal/experimental/inference_service/controller/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

from areal.experimental.inference_service.controller.config import (
GatewayControllerConfig,
)
from areal.api.cli_args import InferenceEngineConfig
from areal.experimental.inference_service.controller.controller import (
GatewayInferenceController,
RolloutControllerV2,
)

__all__ = ["GatewayControllerConfig", "GatewayInferenceController"]
__all__ = ["InferenceEngineConfig", "RolloutControllerV2"]
Loading
Loading