Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions areal/experimental/inference_service/sglang/launch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,21 @@ def areal_launch_server(server_args) -> None:
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process

# ---- BEGIN AREAL ----
from areal.experimental.inference_service.sglang.awex import (
register_awex_endpoints,
)
from areal.experimental.inference_service.sglang.awex import register_awex_endpoints
from areal.experimental.inference_service.sglang.rdt import register_rdt_endpoints
from areal.experimental.inference_service.sglang.rpc_proxy import RpcProxy
from areal.experimental.inference_service.sglang.scheduler import (
areal_run_scheduler_process,
create_result_ipc,
get_weight_update_backend,
)
# ---- END AREAL ----

# ---- BEGIN AREAL ----
result_ipc = create_result_ipc()
backend = getattr(server_args, "weight_update_backend", None)
if backend is None:
backend = get_weight_update_backend()
result_ipc = create_result_ipc(backend)
# ---- END AREAL ----

(
Expand All @@ -60,7 +63,10 @@ def areal_launch_server(server_args) -> None:

# ---- BEGIN AREAL ----
rpc_proxy = RpcProxy(port_args, result_ipc)
register_awex_endpoints(app, rpc_proxy)
if backend == "awex":
register_awex_endpoints(app, rpc_proxy)
elif backend == "rdt":
register_rdt_endpoints(app, rpc_proxy)
# ---- END AREAL ----

try:
Expand Down
92 changes: 92 additions & 0 deletions areal/experimental/inference_service/sglang/rdt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
"""RDT HTTP endpoints for IW weight update.

Reference: areal.experimental.inference_service.sglang.awex
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from areal.utils import logging

if TYPE_CHECKING:
from areal.experimental.inference_service.sglang.rpc_proxy import RpcProxy

logger = logging.getLogger("RDTIWEndpoints")


def register_rdt_endpoints(app: FastAPI, rpc_proxy: RpcProxy) -> None:
    """Register ``/rdt/*`` weight-update endpoints on IW's FastAPI app.

    Each endpoint dispatches to all scheduler processes via RpcProxy,
    using collective_rpc_with_result or collective_rpc.

    Every failure is answered with a JSON body of the form
    ``{"error": "<message>"}``: HTTP 400 when the request body is not a JSON
    object, HTTP 500 when anything else goes wrong while handling the request.

    NOTE(review): the ``rpc_proxy`` calls below are synchronous and run inside
    ``async def`` handlers; if they block for long they will presumably stall
    the event loop -- confirm against RpcProxy before moving them to a thread.

    Args:
        app: FastAPI application
        rpc_proxy: RpcProxy for scheduler subprocess communication
    """

    class _BadRequest(ValueError):
        """Raised when the request body cannot be used as a JSON object."""

    def _error(status_code: int, message: str) -> JSONResponse:
        """Build the uniform ``{"error": ...}`` response."""
        return JSONResponse(status_code=status_code, content={"error": message})

    async def _json_object(request: Request, *, allow_empty: bool = False) -> dict:
        """Return the request body decoded as a JSON object.

        With ``allow_empty`` an absent body is treated as ``{}`` so endpoints
        whose arguments are all optional can be called without a payload.
        """
        if allow_empty and not await request.body():
            return {}
        try:
            payload = await request.json()
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            raise _BadRequest(f"Request body is not valid JSON: {exc}") from exc
        if not isinstance(payload, dict):
            raise _BadRequest(
                f"Expected a JSON object in request body, got {type(payload).__name__}"
            )
        return payload

    @app.get("/rdt/report_parallelism")
    async def report_parallelism() -> JSONResponse:
        """Report IW parallelism strategy for TransferPlan building."""
        try:
            result = rpc_proxy.collective_rpc_with_result("rdt_report_parallelism")
            if not isinstance(result, dict):
                err_msg = f"Expected dict from rdt_report_parallelism, got {type(result).__name__}"
                logger.error(err_msg)
                return _error(500, err_msg)
            return JSONResponse(content=result)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Failed to report parallelism: %s", e)
            return _error(500, str(e))

    @app.post("/rdt/report_weight_meta")
    async def report_weight_meta() -> JSONResponse:
        """Report IW weight metadata for TransferPlan building."""
        try:
            result = rpc_proxy.collective_rpc_with_result("rdt_report_weight_meta")
            return JSONResponse(content={"status": "ok", "meta": result})
        except Exception as e:
            logger.exception("Failed to report weight meta: %s", e)
            return _error(500, str(e))

    @app.post("/rdt/init_weight_update_group")
    async def init_weight_update_group(request: Request) -> JSONResponse:
        """Initialize RDT weight update group.

        Args passed via JSON body:
            pair_name: TW-IW pair identifier
            kv_store_url: Gateway KV store URL
            tw_actor_bytes_b64_list: Base64-encoded TW actor handle bytes
            infer_world_size: Total IW world size
            train_world_size: Total TW world size
            num_engines: Number of IW engines
            transfer_rank: IW's transfer rank
        """
        try:
            data = await _json_object(request)
            # The JSON object is forwarded verbatim as keyword arguments.
            rpc_proxy.collective_rpc("rdt_init_weight_update_group", **data)
            return JSONResponse(content={"status": "ok"})
        except _BadRequest as e:
            logger.warning("Rejected RDT init weight update group request: %s", e)
            return _error(400, str(e))
        except Exception as e:
            logger.exception("Failed to init RDT weight update group: %s", e)
            return _error(500, str(e))

    @app.post("/rdt/execute_weight_update")
    async def execute_weight_update(request: Request) -> JSONResponse:
        """Execute RDT weight update - pull from TW via Ray RPC.

        Args passed via JSON body:
            version: Weight version number (optional, default 0)
        """
        try:
            # ``version`` is optional, so a request without any body is valid.
            data = await _json_object(request, allow_empty=True)
            version = data.get("version", 0)
            rpc_proxy.collective_rpc("rdt_execute_weight_update", version=version)
            return JSONResponse(content={"status": "ok", "version": version})
        except _BadRequest as e:
            logger.warning("Rejected RDT execute weight update request: %s", e)
            return _error(400, str(e))
        except Exception as e:
            logger.exception("Failed to execute RDT weight update: %s", e)
            return _error(500, str(e))
115 changes: 110 additions & 5 deletions areal/experimental/inference_service/sglang/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
"""AwexSchedulerBridge + PPSchedulerBridge: compose weight-update methods onto SGLang Scheduler."""
"""AwexSchedulerBridge/RDTSchedulerBridge + PPSchedulerBridge: compose weight-update methods onto SGLang Scheduler."""

from __future__ import annotations

Expand All @@ -12,9 +12,16 @@
import zmq
from sglang.srt.server_args import PortArgs, ServerArgs

from areal.experimental.weight_update import (
BACKEND_AWEX,
BACKEND_RDT,
WEIGHT_UPDATE_BACKEND_ENV,
get_weight_update_backend,
)
from areal.infra.rpc.serialization import serialize_value

RESULT_IPC_ENV = "AREAL_AWEX_RESULT_IPC"
RDT_RESULT_IPC_ENV = "AREAL_RDT_RESULT_IPC"


class AwexSchedulerBridge:
Expand Down Expand Up @@ -118,6 +125,84 @@ def awex_randomize_parameters(self) -> None:
self._require_adapter().randomize_parameters()


class RDTSchedulerBridge:
    """Compose RDT weight-update capabilities onto a plain Scheduler instance.

    Lifecycle:
      1. Created after ``Scheduler.__init__()`` in :func:`areal_run_scheduler_process`
      2. :meth:`bind` attaches ``rdt_*`` methods to the scheduler via ``setattr``
      3. ``handle_rpc_request`` dispatches via ``getattr(self, method)`` and finds them
      4. Methods delegate to :class:`RDTSGLangAdapter` for actual work
      5. Data-returning methods push results via ZMQ PUSH (tp_rank 0, dp_rank 0 only)

    No inheritance. No monkey-patch. The scheduler instance remains a plain
    ``sglang.srt.managers.scheduler.Scheduler``.
    """

    def __init__(self, scheduler: Any) -> None:
        """Wrap *scheduler* and, on the reporting rank, connect the result socket.

        The PUSH socket is only created when the ``RDT_RESULT_IPC_ENV``
        environment variable is set and this scheduler has ``tp_rank == 0`` and
        either no ``dp_rank`` attribute / ``dp_rank is None`` or ``dp_rank == 0``.
        On every other rank ``_result_push`` stays ``None``.
        """
        self._scheduler = scheduler
        # Built on first use by _require_adapter().
        self._adapter: Any | None = None
        self._result_push: zmq.Socket | None = None

        result_ipc = os.environ.get(RDT_RESULT_IPC_ENV)
        if (
            result_ipc
            and scheduler.tp_rank == 0
            and (getattr(scheduler, "dp_rank", None) is None or scheduler.dp_rank == 0)
        ):
            ctx = zmq.Context(1)
            self._result_push = ctx.socket(zmq.PUSH)
            self._result_push.connect(result_ipc)

    def bind(self) -> None:
        """Attach ``rdt_*`` methods to the scheduler instance."""
        methods = [
            "rdt_report_weight_meta",
            "rdt_report_parallelism",
            "rdt_init_weight_update_group",
            "rdt_execute_weight_update",
        ]
        for name in methods:
            setattr(self._scheduler, name, getattr(self, name))

    def _require_adapter(self) -> Any:
        """Return the RDTSGLangAdapter, creating it on first call."""
        if self._adapter is None:
            # Imported lazily, only when an RDT method is actually invoked.
            from areal.experimental.weight_update.rdt.sglang_adapter import (
                RDTSGLangAdapter,
            )

            self._adapter = RDTSGLangAdapter(self._scheduler)
        return self._adapter

    def _push_result(self, result: Any) -> None:
        """Send *result* over the PUSH socket; no-op on ranks without one."""
        if self._result_push is not None:
            self._result_push.send_pyobj(result)

    def rdt_report_weight_meta(self) -> None:
        """Collect weight metadata and push it from the reporting rank.

        With ``tp_size > 1`` every rank takes part in an ``all_gather_object``
        over ``tp_cpu_group`` and the per-rank lists are concatenated in rank
        order. Ranks without a result socket return right after the collective
        so they do not flatten and serialize data that would be discarded.
        """
        adapter = self._require_adapter()
        local_meta = adapter.get_weight_metadata()
        s = self._scheduler

        gathered: list[list] | None = None
        if s.tp_size > 1:
            # Collective call: must run on every TP rank, pushing or not.
            gathered = [[] for _ in range(s.tp_size)]
            dist.all_gather_object(gathered, local_meta, group=s.tp_cpu_group)

        if self._result_push is None:
            return

        if gathered is None:
            self._push_result(serialize_value(local_meta))
        else:
            all_meta = [entry for rank_meta in gathered for entry in rank_meta]
            self._push_result(serialize_value(all_meta))

    def rdt_report_parallelism(self) -> None:
        """Push the adapter's ``parallelism_strategy`` (sent as-is)."""
        self._push_result(self._require_adapter().parallelism_strategy)

    def rdt_init_weight_update_group(self, **kwargs: Any) -> None:
        """Forward all keyword arguments to the adapter's group initialisation."""
        self._require_adapter().rdt_init_weight_update_group(**kwargs)

    def rdt_execute_weight_update(self, version: int = 0) -> None:
        """Ask the adapter to run the weight update for *version*."""
        self._require_adapter().rdt_execute_weight_update(version)


# ---------------------------------------------------------------------------
# Duplicated from sglang.srt.managers.scheduler.run_scheduler_process
# (SGLang commit pinned in this repo).
Expand Down Expand Up @@ -216,7 +301,11 @@ def areal_run_scheduler_process(
)

# ---- BEGIN AREAL ----
AwexSchedulerBridge(scheduler).bind()
backend = get_weight_update_backend()
if backend == BACKEND_AWEX:
AwexSchedulerBridge(scheduler).bind()
elif backend == BACKEND_RDT:
RDTSchedulerBridge(scheduler).bind()
PPSchedulerBridge(scheduler, server_args).bind()
# ---- END AREAL ----

Expand All @@ -229,7 +318,23 @@ def areal_run_scheduler_process(
parent_process.send_signal(signal.SIGQUIT)


def create_result_ipc(backend: str) -> str:
    """Create result IPC path for given backend.

    Sets environment variables for the scheduler subprocess to read: the
    backend-specific result-IPC variable (only for the two known backends) and
    ``WEIGHT_UPDATE_BACKEND_ENV``, which is always set to *backend*.

    Args:
        backend: "awex" or "rdt"

    Returns:
        IPC path string of the form ``ipc://<temp file path>``
    """
    # tempfile.mktemp only returns a name and is deprecated because another
    # process can claim that name first. mkstemp reserves the name by creating
    # the file.
    # NOTE(review): this assumes the endpoint is bound by ZeroMQ, which is
    # understood to unlink a stale file at an ipc:// path before binding --
    # confirm against the code that binds this address.
    fd, socket_file = tempfile.mkstemp(prefix=f"areal_{backend}_result_")
    os.close(fd)
    path = f"ipc://{socket_file}"

    if backend == BACKEND_AWEX:
        os.environ[RESULT_IPC_ENV] = path
    elif backend == BACKEND_RDT:
        os.environ[RDT_RESULT_IPC_ENV] = path

    os.environ[WEIGHT_UPDATE_BACKEND_ENV] = backend
    return path
27 changes: 20 additions & 7 deletions areal/experimental/training_service/worker/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from areal.experimental.training_service.worker.awex import create_awex_blueprint
from areal.experimental.training_service.worker.config import TrainWorkerConfig
from areal.experimental.training_service.worker.engine import create_engine_module
from areal.experimental.training_service.worker.rdt import create_rdt_blueprint
from areal.experimental.weight_update import get_weight_update_backend
from areal.infra.platforms import current_platform
from areal.infra.rpc.serialization import deserialize_value, serialize_value
from areal.utils import logging
Expand Down Expand Up @@ -198,14 +200,25 @@ def _get_node_addr() -> str:
)
)

app.register_blueprint(
create_awex_blueprint(
flask_module=flask,
get_engine=_get_engine,
submit_to_engine_thread=_submit_to_engine_thread,
run_endpoint=_run_endpoint,
backend = get_weight_update_backend()
if backend == "awex":
app.register_blueprint(
create_awex_blueprint(
flask_module=flask,
get_engine=_get_engine,
submit_to_engine_thread=_submit_to_engine_thread,
run_endpoint=_run_endpoint,
)
)
elif backend == "rdt":
app.register_blueprint(
create_rdt_blueprint(
flask_module=flask,
get_engine=_get_engine,
submit_to_engine_thread=_submit_to_engine_thread,
run_endpoint=_run_endpoint,
)
)
)

from areal.infra.rpc.guard.data_blueprint import data_bp

Expand Down
Loading