Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
logger = logging.getLogger("RolloutControllerV2")

_MAX_COMPLETED_ONLINE_RESULTS = 1024
_DEFAULT_SERVICE_LOG_LEVEL = "info"
_DEFAULT_SERVICE_LOG_LEVEL = "warning"


@dataclass
Expand Down
45 changes: 45 additions & 0 deletions areal/experimental/inference_service/sglang/awex.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,48 @@ async def randomize_parameters() -> JSONResponse:
except Exception as e:
logger.error("Failed to randomize parameters: %s", e)
return JSONResponse(status_code=500, content={"error": str(e)})

@app.post("/awex/init_colocate_weight_update")
async def init_colocate_weight_update(request: Request) -> JSONResponse:
try:
data = await request.json()
rpc_proxy.collective_rpc("awex_init_colocate_weight_update", **data)
return JSONResponse(content={"status": "ok"})
except Exception as e:
logger.error("Failed to init colocate weight update: %s", e)
return JSONResponse(status_code=500, content={"error": str(e)})

@app.post("/awex/execute_colocate_weight_update")
async def execute_colocate_weight_update(request: Request) -> JSONResponse:
try:
data = await request.json()
version = data.get("version", 0)
rpc_proxy.collective_rpc(
"awex_execute_colocate_weight_update", version=version
)
return JSONResponse(content={"status": "ok", "version": version})
except Exception as e:
logger.error("Failed to execute colocate weight update: %s", e)
return JSONResponse(status_code=500, content={"error": str(e)})

@app.post("/awex/release_memory")
async def release_memory(request: Request) -> JSONResponse:
try:
data = await request.json()
tags = data.get("tags")
rpc_proxy.collective_rpc("awex_release_memory", tags=tags)
return JSONResponse(content={"status": "ok"})
except Exception as e:
logger.error("Failed to release memory: %s", e)
return JSONResponse(status_code=500, content={"error": str(e)})

@app.post("/awex/resume_memory")
async def resume_memory(request: Request) -> JSONResponse:
try:
data = await request.json()
tags = data.get("tags")
rpc_proxy.collective_rpc("awex_resume_memory", tags=tags)
return JSONResponse(content={"status": "ok"})
except Exception as e:
logger.error("Failed to resume memory: %s", e)
return JSONResponse(status_code=500, content={"error": str(e)})
16 changes: 16 additions & 0 deletions areal/experimental/inference_service/sglang/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def bind(self) -> None:
"awex_batch_isend_irecv",
"awex_get_parameters",
"awex_randomize_parameters",
"awex_init_colocate_weight_update",
"awex_execute_colocate_weight_update",
"awex_release_memory",
"awex_resume_memory",
]
for name in methods:
setattr(self._scheduler, name, getattr(self, name))
Expand Down Expand Up @@ -117,6 +121,18 @@ def awex_get_parameters(
def awex_randomize_parameters(self) -> None:
self._require_adapter().randomize_parameters()

def awex_init_colocate_weight_update(self, **kwargs: Any) -> None:
self._require_adapter().init_colocate_weight_update(**kwargs)

def awex_execute_colocate_weight_update(self, version: int = 0) -> None:
self._require_adapter().execute_colocate_weight_update(version)

def awex_release_memory(self, tags: list[str] | None = None) -> None:
self._require_adapter().release_memory(tags)

def awex_resume_memory(self, tags: list[str] | None = None) -> None:
self._require_adapter().resume_memory(tags)


# ---------------------------------------------------------------------------
# Duplicated from sglang.srt.managers.scheduler.run_scheduler_process
Expand Down
59 changes: 59 additions & 0 deletions areal/experimental/training_service/worker/awex.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,65 @@ def action():
return_result=False,
)

@bp.route("/init_colocate_weight_update", methods=["POST"])
def init_colocate_weight_update():
data = flask_module.request.get_json(force=True)

def action():
adapter = _require_adapter()
adapter.init_colocate_weight_update(**data)

return run_endpoint(
"init_colocate_weight_update",
lambda: submit_to_engine_thread("init_colocate_weight_update", action),
return_result=False,
)

@bp.route("/execute_colocate_weight_update", methods=["POST"])
def execute_colocate_weight_update():
data = flask_module.request.get_json(force=True)
version = data.get("version", 0)

def action():
adapter = _require_adapter()
adapter.execute_colocate_weight_update(version)

return run_endpoint(
"execute_colocate_weight_update",
lambda: submit_to_engine_thread("execute_colocate_weight_update", action),
return_result=False,
)

@bp.route("/release_memory", methods=["POST"])
def release_memory():
data = flask_module.request.get_json(force=True)
tags = data.get("tags")

def action():
adapter = _require_adapter()
adapter.release_memory(tags)

return run_endpoint(
"release_memory",
lambda: submit_to_engine_thread("release_memory", action),
return_result=False,
)

@bp.route("/resume_memory", methods=["POST"])
def resume_memory():
data = flask_module.request.get_json(force=True)
tags = data.get("tags")

def action():
adapter = _require_adapter()
adapter.resume_memory(tags)

return run_endpoint(
"resume_memory",
lambda: submit_to_engine_thread("resume_memory", action),
return_result=False,
)

@bp.route("/debug/get_parameters", methods=["POST"])
def get_parameters():
"""Save local shard parameters to a file for test validation."""
Expand Down
Loading
Loading