diff --git a/areal/experimental/inference_service/controller/controller.py b/areal/experimental/inference_service/controller/controller.py index 0a9af9ad77..155e318a02 100644 --- a/areal/experimental/inference_service/controller/controller.py +++ b/areal/experimental/inference_service/controller/controller.py @@ -38,7 +38,7 @@ logger = logging.getLogger("RolloutControllerV2") _MAX_COMPLETED_ONLINE_RESULTS = 1024 -_DEFAULT_SERVICE_LOG_LEVEL = "info" +_DEFAULT_SERVICE_LOG_LEVEL = "warning" @dataclass diff --git a/areal/experimental/inference_service/sglang/awex.py b/areal/experimental/inference_service/sglang/awex.py index 4c9171f824..6bd3f3c015 100644 --- a/areal/experimental/inference_service/sglang/awex.py +++ b/areal/experimental/inference_service/sglang/awex.py @@ -94,3 +94,48 @@ async def randomize_parameters() -> JSONResponse: except Exception as e: logger.error("Failed to randomize parameters: %s", e) return JSONResponse(status_code=500, content={"error": str(e)}) + + @app.post("/awex/init_colocate_weight_update") + async def init_colocate_weight_update(request: Request) -> JSONResponse: + try: + data = await request.json() + rpc_proxy.collective_rpc("awex_init_colocate_weight_update", **data) + return JSONResponse(content={"status": "ok"}) + except Exception as e: + logger.error("Failed to init colocate weight update: %s", e) + return JSONResponse(status_code=500, content={"error": str(e)}) + + @app.post("/awex/execute_colocate_weight_update") + async def execute_colocate_weight_update(request: Request) -> JSONResponse: + try: + data = await request.json() + version = data.get("version", 0) + rpc_proxy.collective_rpc( + "awex_execute_colocate_weight_update", version=version + ) + return JSONResponse(content={"status": "ok", "version": version}) + except Exception as e: + logger.error("Failed to execute colocate weight update: %s", e) + return JSONResponse(status_code=500, content={"error": str(e)}) + + @app.post("/awex/release_memory") + async def release_memory(request: Request) -> JSONResponse: + try: + data = await request.json() + tags = data.get("tags") + rpc_proxy.collective_rpc("awex_release_memory", tags=tags) + return JSONResponse(content={"status": "ok"}) + except Exception as e: + logger.error("Failed to release memory: %s", e) + return JSONResponse(status_code=500, content={"error": str(e)}) + + @app.post("/awex/resume_memory") + async def resume_memory(request: Request) -> JSONResponse: + try: + data = await request.json() + tags = data.get("tags") + rpc_proxy.collective_rpc("awex_resume_memory", tags=tags) + return JSONResponse(content={"status": "ok"}) + except Exception as e: + logger.error("Failed to resume memory: %s", e) + return JSONResponse(status_code=500, content={"error": str(e)}) diff --git a/areal/experimental/inference_service/sglang/scheduler.py b/areal/experimental/inference_service/sglang/scheduler.py index 7d8b3a86d9..927162d826 100644 --- a/areal/experimental/inference_service/sglang/scheduler.py +++ b/areal/experimental/inference_service/sglang/scheduler.py @@ -62,6 +62,10 @@ def bind(self) -> None: "awex_batch_isend_irecv", "awex_get_parameters", "awex_randomize_parameters", + "awex_init_colocate_weight_update", + "awex_execute_colocate_weight_update", + "awex_release_memory", + "awex_resume_memory", ] for name in methods: setattr(self._scheduler, name, getattr(self, name)) @@ -117,6 +121,18 @@ def awex_get_parameters( def awex_randomize_parameters(self) -> None: self._require_adapter().randomize_parameters() + def awex_init_colocate_weight_update(self, **kwargs: Any) -> None: + self._require_adapter().init_colocate_weight_update(**kwargs) + + def awex_execute_colocate_weight_update(self, version: int = 0) -> None: + self._require_adapter().execute_colocate_weight_update(version) + + def awex_release_memory(self, tags: list[str] | None = None) -> None: + self._require_adapter().release_memory(tags) + + def awex_resume_memory(self, tags: list[str] | None = None) -> None: + self._require_adapter().resume_memory(tags) + # --------------------------------------------------------------------------- # Duplicated from sglang.srt.managers.scheduler.run_scheduler_process diff --git a/areal/experimental/training_service/worker/awex.py b/areal/experimental/training_service/worker/awex.py index f1d13cbb18..e206810b83 100644 --- a/areal/experimental/training_service/worker/awex.py +++ b/areal/experimental/training_service/worker/awex.py @@ -113,6 +113,65 @@ def action(): return_result=False, ) + @bp.route("/init_colocate_weight_update", methods=["POST"]) + def init_colocate_weight_update(): + data = flask_module.request.get_json(force=True) + + def action(): + adapter = _require_adapter() + adapter.init_colocate_weight_update(**data) + + return run_endpoint( + "init_colocate_weight_update", + lambda: submit_to_engine_thread("init_colocate_weight_update", action), + return_result=False, + ) + + @bp.route("/execute_colocate_weight_update", methods=["POST"]) + def execute_colocate_weight_update(): + data = flask_module.request.get_json(force=True) + version = data.get("version", 0) + + def action(): + adapter = _require_adapter() + adapter.execute_colocate_weight_update(version) + + return run_endpoint( + "execute_colocate_weight_update", + lambda: submit_to_engine_thread("execute_colocate_weight_update", action), + return_result=False, + ) + + @bp.route("/release_memory", methods=["POST"]) + def release_memory(): + data = flask_module.request.get_json(force=True) + tags = data.get("tags") + + def action(): + adapter = _require_adapter() + adapter.release_memory(tags) + + return run_endpoint( + "release_memory", + lambda: submit_to_engine_thread("release_memory", action), + return_result=False, + ) + + @bp.route("/resume_memory", methods=["POST"]) + def resume_memory(): + data = flask_module.request.get_json(force=True) + tags = data.get("tags") + + def action(): + adapter = _require_adapter() + adapter.resume_memory(tags) + + return run_endpoint( + "resume_memory", + lambda: submit_to_engine_thread("resume_memory", action), + return_result=False, + ) + @bp.route("/debug/get_parameters", methods=["POST"]) def get_parameters(): """Save local shard parameters to a file for test validation.""" diff --git a/areal/experimental/weight_update/awex/megatron_adapter.py b/areal/experimental/weight_update/awex/megatron_adapter.py index 95e1fd731e..6002a7fa8b 100644 --- a/areal/experimental/weight_update/awex/megatron_adapter.py +++ b/areal/experimental/weight_update/awex/megatron_adapter.py @@ -1,9 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +import gc import os +import threading +import time from typing import TYPE_CHECKING +import httpx import torch import torch.distributed as dist from awex.meta.weight_meta import ( @@ -15,6 +19,10 @@ from awex.sharding.rank_info import RankInfo from awex.transfer.nccl_comm import batch_send_recv, nccl_build_send_ops from awex.transfer.transfer_plan import TransferPlan, TransferPlanBuilder +from awex.util.tensor_util import ( + cuda_ipc_serialize, + group_tensors_by_shape_and_dtype, +) from areal.experimental.weight_update.awex import fetch_kv_metadata from areal.experimental.weight_update.nccl_group import ( @@ -51,6 +59,13 @@ def __init__(self, engine: MegatronEngine): self._transfer_plan: TransferPlan | None = None self._weights_update_group = None self._transfer_rank: int | None = None + self._offloaded_optimizer_states: dict = {} + self._offloaded_weights: dict[str, torch.Tensor] = {} + self._released_tags: set[str] = set() + self._colocate_lock = threading.Lock() + self._colocate_admin_api_key: str = "areal-admin-key" + self._colocate_http_client: httpx.Client | None = None + self._colocate_timeout_s: float = 120.0 @property def parallelism_strategy(self) -> dict: @@ -121,9 +136,16 @@ def get_local_shard_parameters( return result def save_parameters(self, save_path: str, names: list[str] | None = None) -> None: - params = self.get_local_shard_parameters(names) - cpu_params = {k: v.detach().cpu().clone() for k, v in params.items()} - torch.save(cpu_params, save_path) + weights_offloaded = "weights" in self._released_tags + if weights_offloaded: + self.resume_memory(tags=["weights"]) + try: + params = self.get_local_shard_parameters(names) + cpu_params = {k: v.detach().cpu().clone() for k, v in params.items()} + torch.save(cpu_params, save_path) + finally: + if weights_offloaded: + self.release_memory(tags=["weights"]) def init_weight_update_group( self, @@ -194,6 +216,9 @@ def teardown_weight_update_group(self) -> None: self._weights_update_group = None self._transfer_plan = None self._transfer_rank = None + if self._colocate_http_client is not None: + self._colocate_http_client.close() + self._colocate_http_client = None def _build_rank_info(self) -> RankInfo: from megatron.core import parallel_state as mpu @@ -276,3 +301,250 @@ def _iter_hf_params(self): if tie_word_embeddings and hf_name == "lm_head.weight": continue yield hf_name, tensor.detach() + + # ── Colocated weight transfer methods ───────────────────────────────── + + def init_colocate_weight_update( + self, + pair_name: str, + kv_store_url: str, + transfer_rank: int, + infer_world_size: int, + train_world_size: int, + num_engines: int, + master_port: int, + admin_api_key: str = "areal-admin-key", + timeout_s: float = 120.0, + ) -> None: + self._colocate_pair_name = pair_name + self._colocate_kv_store_url = kv_store_url + self._colocate_transfer_rank = transfer_rank + self._colocate_infer_world_size = infer_world_size + self._colocate_admin_api_key = admin_api_key + self._colocate_timeout_s = timeout_s + if self._colocate_http_client is None: + self._colocate_http_client = httpx.Client() + logger.info( + "Initialized colocate weight update for pair '%s', transfer_rank=%d", + pair_name, + transfer_rank, + ) + + def execute_colocate_weight_update(self, version: int) -> None: + with self._colocate_lock: + self._execute_colocate_weight_update_locked(version) + + def _execute_colocate_weight_update_locked(self, version: int) -> None: + kv_store_url = self._colocate_kv_store_url + pair_name = self._colocate_pair_name + transfer_rank = self._colocate_transfer_rank + assert self._colocate_http_client is not None, ( + "init_colocate_weight_update must be called first" + ) + client = self._colocate_http_client + auth_headers = {"Authorization": f"Bearer {self._colocate_admin_api_key}"} + timeout_s = self._colocate_timeout_s + + weights_offloaded = "weights" in self._released_tags + if weights_offloaded: + self.resume_memory(tags=["weights"]) + + params = self.get_local_shard_parameters() + tensors = list(params.values()) + names = list(params.keys()) + + group_tensors, metadata = group_tensors_by_shape_and_dtype(tensors) + torch.cuda.synchronize() + + del tensors + + group_shared = [t.share_memory_() for t in group_tensors] + serialized_weights = cuda_ipc_serialize((group_shared, metadata, names)) + torch.cuda.synchronize() + + kv_key = f"colocate_weights_rank{transfer_rank}_{version}" + + client.put( + f"{kv_store_url}/weight_meta/{pair_name}/{kv_key}", + json={"value": serialized_weights.hex()}, + headers=auth_headers, + timeout=timeout_s, + ) + + logger.info( + "Serialized %d params (%d groups) for colocate transfer v%d, rank %d", + len(names), + len(group_shared), + version, + transfer_rank, + ) + + done_key = f"colocate_done_rank{transfer_rank}_{version}" + deadline = time.monotonic() + timeout_s + poll_count = 0 + last_status = -1 + while time.monotonic() < deadline: + resp = client.get( + f"{kv_store_url}/weight_meta/{pair_name}/{done_key}", + timeout=5.0, + ) + last_status = resp.status_code + if resp.status_code == 200: + break + poll_count += 1 + time.sleep(0.1) + else: + raise TimeoutError( + f"Inference did not signal completion within {timeout_s}s " + f"(waiting_key={done_key}, put_key={kv_key}, " + f"polls={poll_count}, last_status={last_status})" + ) + + del group_shared, group_tensors, serialized_weights + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + if weights_offloaded: + self.release_memory(tags=["weights"]) + + def release_memory(self, tags: list[str] | None = None) -> None: + """Release GPU memory for specified tags by offloading to CPU. + + Supported tags: + - "optimizer": Offload optimizer state tensors (exp_avg, exp_avg_sq, etc.) + - "weights": Offload model parameters + """ + tags = tags or ["optimizer", "weights"] + tags_to_release = [t for t in tags if t not in self._released_tags] + if not tags_to_release: + logger.info("release_memory: tags=%s already released, skipping", tags) + return + + logger.info("release_memory: offloading tags=%s", tags_to_release) + + if "optimizer" in tags_to_release: + self._offload_optimizer_states() + self._released_tags.add("optimizer") + + if "weights" in tags_to_release: + self._offload_model_weights() + self._released_tags.add("weights") + + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + logger.info("release_memory: done for tags=%s", tags_to_release) + + def resume_memory(self, tags: list[str] | None = None) -> None: + """Resume GPU memory for specified tags by reloading from CPU. + + Supported tags: + - "optimizer": Reload optimizer state tensors to GPU + - "weights": Reload model parameters to GPU + """ + tags = tags or ["optimizer", "weights"] + tags_to_resume = [t for t in tags if t in self._released_tags] + if not tags_to_resume: + logger.info("resume_memory: tags=%s not released, skipping", tags) + return + + logger.info("resume_memory: reloading tags=%s", tags_to_resume) + + if "weights" in tags_to_resume: + self._reload_model_weights() + self._released_tags.discard("weights") + + if "optimizer" in tags_to_resume: + self._reload_optimizer_states() + self._released_tags.discard("optimizer") + + torch.cuda.synchronize() + logger.info("resume_memory: done for tags=%s", tags_to_resume) + + def _offload_optimizer_states(self) -> None: + """Move optimizer state tensors to CPU, keeping references for reload.""" + optimizer = self._engine.optimizer + if optimizer is None: + logger.warning("No optimizer found, skipping optimizer offload") + return + + # Megatron's ChainedOptimizer wraps per-model-chunk optimizers; + # each in turn wraps a base torch optimizer holding the state dict. + if hasattr(optimizer, "optimizers"): + inner_optimizers = optimizer.optimizers + else: + inner_optimizers = [optimizer] + logger.warning( + "Optimizer does not have 'optimizers' attribute. " + "Treating it as a single optimizer; offload may be incomplete " + "for non-standard Megatron optimizer structures." + ) + for opt in inner_optimizers: + base_opt = getattr(opt, "optimizer", opt) + for param, state in base_opt.state.items(): + cpu_state: dict[str, torch.Tensor] = {} + for key, val in state.items(): + if isinstance(val, torch.Tensor) and val.is_cuda: + cpu_state[key] = val.detach().to("cpu", non_blocking=True) + state[key] = torch.empty(0, device="cpu") + if cpu_state: + self._offloaded_optimizer_states[param] = cpu_state + + logger.info( + "Offloaded optimizer states for %d params", + len(self._offloaded_optimizer_states), + ) + + def _reload_optimizer_states(self) -> None: + """Restore optimizer state tensors from CPU back to GPU.""" + if not self._offloaded_optimizer_states: + return + + optimizer = self._engine.optimizer + if optimizer is None: + return + + inner_optimizers = getattr(optimizer, "optimizers", [optimizer]) + for opt in inner_optimizers: + base_opt = getattr(opt, "optimizer", opt) + for param, state in base_opt.state.items(): + if param in self._offloaded_optimizer_states: + cpu_state = self._offloaded_optimizer_states[param] + for key, val in cpu_state.items(): + state[key] = val.to(param.device, non_blocking=True) + + self._offloaded_optimizer_states.clear() + logger.info("Reloaded optimizer states to GPU") + + def _offload_model_weights(self) -> None: + """Move model parameters to CPU, keeping references for reload.""" + if self._engine.model is None: + return + + for name, param in self._engine.model.named_parameters(): + if param.is_cuda: + self._offloaded_weights[name] = param.data.detach().to( + "cpu", non_blocking=True + ) + param.data = torch.empty(0, device="cpu") + + logger.info( + "Offloaded %d model weight tensors to CPU", + len(self._offloaded_weights), + ) + + def _reload_model_weights(self) -> None: + """Restore model parameters from CPU back to GPU.""" + if not self._offloaded_weights: + return + if self._engine.model is None: + return + + device = self._engine.device + for name, param in self._engine.model.named_parameters(): + if name in self._offloaded_weights: + param.data = self._offloaded_weights[name].to(device, non_blocking=True) + + self._offloaded_weights.clear() + logger.info("Reloaded model weights to GPU") diff --git a/areal/experimental/weight_update/awex/sglang_adapter.py b/areal/experimental/weight_update/awex/sglang_adapter.py index 86bbe7ac17..4ebcf4c8c5 100644 --- a/areal/experimental/weight_update/awex/sglang_adapter.py +++ b/areal/experimental/weight_update/awex/sglang_adapter.py @@ -2,10 +2,13 @@ # pyright: reportMissingImports=false from __future__ import annotations +import gc import math import os +import time from typing import Any +import httpx import torch import torch.distributed as dist from awex.meta.weight_meta import ( @@ -20,7 +23,12 @@ get_sglang_sharding_strategy, ) from awex.transfer.nccl_comm import batch_send_recv, nccl_build_recv_ops +from awex.transfer.nccl_stream_batch import NcclColocateStreamBatchTransport from awex.transfer.transfer_plan import TransferPlan, TransferPlanBuilder +from awex.util.tensor_util import ( + cuda_ipc_deserialize, + reconstruct_tensors_from_groups, +) from areal.experimental.weight_update.awex import fetch_kv_metadata from areal.experimental.weight_update.inference_adapter import ( @@ -45,6 +53,13 @@ def __init__(self, scheduler: Any): self._transfer_rank: int | None = None self._rank_info: RankInfo | None = None self._parameters: dict[str, torch.Tensor] | None = None + self._released_tags: set[str] = set() + self._colocate_admin_api_key: str = "areal-admin-key" + self._colocate_http_client: httpx.Client | None = None + self._colocate_timeout_s: float = 120.0 + self._colocate_transport = None + self._train_to_infer_device_mapping: dict | None = None + self._infer_to_train_device_mapping: dict | None = None def _get_model(self) -> torch.nn.Module: return self._scheduler.tp_worker.model_runner.model @@ -402,3 +417,225 @@ def teardown_weight_update_group(self) -> None: self._transfer_rank = None self._rank_info = None self._parameters = None + if self._colocate_http_client is not None: + self._colocate_http_client.close() + self._colocate_http_client = None + self._colocate_transport = None + self._train_to_infer_device_mapping = None + self._infer_to_train_device_mapping = None + + # ── Colocated weight transfer methods ───────────────────────────────── + + def init_colocate_weight_update( + self, + pair_name: str, + kv_store_url: str, + transfer_rank: int, + infer_world_size: int, + train_world_size: int, + num_engines: int, + master_port: int, + admin_api_key: str = "areal-admin-key", + timeout_s: float = 120.0, + ) -> None: + if infer_world_size != train_world_size: + raise ValueError( + f"Colocate mode requires infer_world_size == train_world_size. " + f"Got infer_world_size={infer_world_size}, " + f"train_world_size={train_world_size}" + ) + self._colocate_pair_name = pair_name + self._colocate_kv_store_url = kv_store_url + self._transfer_rank = transfer_rank + self._colocate_infer_world_size = infer_world_size + self._colocate_train_world_size = train_world_size + self._colocate_admin_api_key = admin_api_key + self._colocate_timeout_s = timeout_s + if self._colocate_http_client is None: + self._colocate_http_client = httpx.Client() + + infer_meta, train_meta = fetch_kv_metadata(kv_store_url, pair_name) + + builder = TransferPlanBuilder( + infer_world_size=infer_world_size, + train_world_size=train_world_size, + num_infer_engines=num_engines, + ) + + train_to_infer = {} + infer_to_train = {} + for i in range(min(infer_world_size, train_world_size)): + train_rank = infer_world_size + i + train_to_infer[train_rank] = i + infer_to_train[i] = train_rank + self._train_to_infer_device_mapping = train_to_infer + self._infer_to_train_device_mapping = infer_to_train + + self._send_transfer_plan = builder.build_local_transfer_plan( + infer_meta, + train_meta, + global_transfer_rank=infer_to_train[transfer_rank], + ) + self._recv_transfer_plan = builder.build_local_transfer_plan( + infer_meta, + train_meta, + global_transfer_rank=transfer_rank, + ) + + os.environ["TORCHELASTIC_USE_AGENT_STORE"] = str(False) + self._weights_update_group = init_weights_update_group( + master_address="127.0.0.1", + master_port=master_port, + rank=transfer_rank, + world_size=infer_world_size, + group_name=f"awex_colocate_{pair_name}", + role="inference", + ) + + self._colocate_transport = NcclColocateStreamBatchTransport( + transfer_rank, infer_world_size + ) + + logger.info( + "Initialized colocate weight update for pair '%s', " + "transfer_rank=%d, infer_world_size=%d", + pair_name, + transfer_rank, + infer_world_size, + ) + + def execute_colocate_weight_update(self, version: int) -> None: + kv_store_url = self._colocate_kv_store_url + pair_name = self._colocate_pair_name + transfer_rank = self._transfer_rank + assert self._colocate_http_client is not None, ( + "init_colocate_weight_update must be called first" + ) + assert self._infer_to_train_device_mapping is not None + client = self._colocate_http_client + auth_headers = {"Authorization": f"Bearer {self._colocate_admin_api_key}"} + timeout_s = self._colocate_timeout_s + + paired_train_rank = self._infer_to_train_device_mapping[transfer_rank] + kv_key = f"colocate_weights_rank{paired_train_rank}_{version}" + + deadline = time.monotonic() + timeout_s + serialized_hex = None + poll_count = 0 + last_status = -1 + while time.monotonic() < deadline: + resp = client.get( + f"{kv_store_url}/weight_meta/{pair_name}/{kv_key}", + timeout=5.0, + ) + last_status = resp.status_code + if resp.status_code == 200: + serialized_hex = resp.json()["value"] + break + poll_count += 1 + time.sleep(0.1) + if serialized_hex is None: + raise TimeoutError( + f"Training did not put colocate weights within {timeout_s}s " + f"(waiting_key={kv_key}, polls={poll_count}, " + f"last_status={last_status})" + ) + + serialized_weights = bytes.fromhex(serialized_hex) + group_shared, metadata, names = cuda_ipc_deserialize(serialized_weights) + torch.cuda.synchronize() + tensors = reconstruct_tensors_from_groups(group_shared, metadata) + torch.cuda.synchronize() + deserialized_weights = dict(zip(names, tensors)) + + recv_parameters = self.get_local_shard_parameters() + + rank_info = self._build_rank_info() + rank_coordinate = f"infer_{rank_info.global_rank}" + + assert self._colocate_transport is not None + self._colocate_transport.update_weights_in_colocate_mode( + self._train_to_infer_device_mapping, + self._infer_to_train_device_mapping, + transfer_rank, + rank_coordinate, + self._colocate_infer_world_size, + self._send_transfer_plan, + self._recv_transfer_plan, + self._weights_update_group, + deserialized_weights, + recv_parameters, + step_id=version, + ) + + done_key = f"colocate_done_rank{paired_train_rank}_{version}" + client.put( + f"{kv_store_url}/weight_meta/{pair_name}/{done_key}", + json={"value": True}, + headers=auth_headers, + timeout=10.0, + ) + + del deserialized_weights, group_shared, tensors, serialized_weights + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + logger.info( + "Colocate weight update completed for v%d, rank %d", + version, + transfer_rank, + ) + + # Tags understood by SGLang's native release/resume_memory_occupation. + _SGLANG_MEMORY_TAGS = {"kv_cache"} + + def release_memory(self, tags: list[str] | None = None) -> None: + from sglang.srt.managers.io_struct import ReleaseMemoryOccupationReqInput + + native_tags = ( + [t for t in tags if t in self._SGLANG_MEMORY_TAGS] if tags else None + ) + unsupported = ( + [t for t in tags if t not in self._SGLANG_MEMORY_TAGS] if tags else [] + ) + if unsupported: + logger.warning( + "release_memory: tags %s not supported by SGLang adapter " + "(supported: %s), ignoring", + unsupported, + self._SGLANG_MEMORY_TAGS, + ) + if native_tags: + req = ReleaseMemoryOccupationReqInput(tags=native_tags) + self._scheduler.release_memory_occupation(req) + self._released_tags.update(native_tags) + logger.info("release_memory completed with tags=%s", tags) + + def resume_memory(self, tags: list[str] | None = None) -> None: + from sglang.srt.managers.io_struct import ResumeMemoryOccupationReqInput + + native_tags = ( + [ + t + for t in tags + if t in self._SGLANG_MEMORY_TAGS and t in self._released_tags + ] + if tags + else None + ) + unsupported = ( + [t for t in tags if t not in self._SGLANG_MEMORY_TAGS] if tags else [] + ) + if unsupported: + logger.warning( + "resume_memory: tags %s not supported by SGLang adapter " + "(supported: %s), ignoring", + unsupported, + self._SGLANG_MEMORY_TAGS, + ) + if native_tags: + req = ResumeMemoryOccupationReqInput(tags=native_tags) + self._scheduler.resume_memory_occupation(req) + self._released_tags.difference_update(native_tags) + logger.info("resume_memory completed with tags=%s", tags) diff --git a/areal/experimental/weight_update/controller/controller.py b/areal/experimental/weight_update/controller/controller.py index 01dad28a72..a119e1b59f 100644 --- a/areal/experimental/weight_update/controller/controller.py +++ b/areal/experimental/weight_update/controller/controller.py @@ -112,6 +112,7 @@ def connect( save_path: str = "", use_lora: bool = False, lora_name: str = "", + colocate: bool = False, ) -> None: self._pair_name = pair_name resp = self._http.post( @@ -124,11 +125,14 @@ def connect( "save_path": save_path, "use_lora": use_lora, "lora_name": lora_name, + "colocate": colocate, }, timeout=self.config.request_timeout, ) resp.raise_for_status() - logger.info("Connected pair '%s' (mode=%s)", pair_name, mode) + logger.info( + "Connected pair '%s' (mode=%s, colocate=%s)", pair_name, mode, colocate + ) def update_weights(self, version: int) -> WeightUpdateResult: if self._pair_name is None: diff --git a/areal/experimental/weight_update/gateway/app.py b/areal/experimental/weight_update/gateway/app.py index 9a461e2a02..f68539a5b8 100644 --- a/areal/experimental/weight_update/gateway/app.py +++ b/areal/experimental/weight_update/gateway/app.py @@ -22,6 +22,7 @@ from areal.experimental.weight_update.gateway.kv_store import WeightMetaStore from areal.experimental.weight_update.gateway.pair_registry import PairRegistry from areal.utils import logging +from areal.utils.network import find_free_ports logger = logging.getLogger("WeightUpdateGateway") @@ -34,6 +35,7 @@ class ConnectRequest(BaseModel): save_path: str = "" use_lora: bool = False lora_name: str = "" + colocate: bool = False class UpdateWeightsRequest(BaseModel): @@ -196,6 +198,11 @@ async def connect(request: Request, body: ConnectRequest) -> ConnectResponse: train_urls = body.train_worker_urls inference_urls = body.inference_worker_urls + if body.colocate: + return await _connect_colocate( + request, pair_name, train_urls, inference_urls + ) + if body.mode == "disk": if not body.save_path: return JSONResponse( @@ -363,6 +370,212 @@ async def connect(request: Request, body: ConnectRequest) -> ConnectResponse: logger.info("Connected pair '%s'", pair_name) return ConnectResponse(pair_name=pair_name) + async def _connect_colocate( + request: Request, + pair_name: str, + train_urls: list[str], + inference_urls: list[str], + ) -> ConnectResponse: + session = request.app.state.http_session + init_timeout_s = config.init_timeout_s + + train_par, infer_par = await asyncio.gather( + _get_json( + session, + f"{train_urls[0]}/awex/report_parallelism", + init_timeout_s, + ), + _get_json( + session, + f"{inference_urls[0]}/awex/report_parallelism", + init_timeout_s, + ), + ) + + train_world_size = train_par["world_size"] + num_engines = len(inference_urls) + # report_parallelism returns per-instance world_size (e.g. TP size). + # The total inference world for colocate NCCL groups spans all engines. + infer_world_size = infer_par["world_size"] * num_engines + + train_meta_resps, infer_meta_resps = await asyncio.gather( + asyncio.gather( + *[ + _post_json( + session, f"{url}/awex/report_weight_meta", init_timeout_s + ) + for url in train_urls + ] + ), + asyncio.gather( + *[ + _post_json( + session, f"{url}/awex/report_weight_meta", init_timeout_s + ) + for url in inference_urls + ] + ), + ) + + training_params_meta = [] + for result in train_meta_resps: + meta = result.get("result", result.get("meta", result)) + if isinstance(meta, list): + training_params_meta.extend(meta) + else: + training_params_meta.append(meta) + training_params_meta = _merge_training_meta_by_name(training_params_meta) + + infer_params_meta = [] + for result in infer_meta_resps: + meta = result.get("result", result.get("meta", result)) + if isinstance(meta, list): + infer_params_meta.extend(meta) + else: + infer_params_meta.append(meta) + + kv_store.put(pair_name, "training_params_meta", training_params_meta) + kv_store.put(pair_name, "infer_params_meta", infer_params_meta) + + gateway_addr = ( + _get_own_ip() if config.host in ("0.0.0.0", "::") else config.host + ) + kv_store_url = f"http://{gateway_addr}:{config.gateway_port}" + + [master_port] = find_free_ports(1) + + init_payload_base = { + "pair_name": pair_name, + "kv_store_url": kv_store_url, + "infer_world_size": infer_world_size, + "train_world_size": train_world_size, + "num_engines": num_engines, + "master_port": master_port, + "admin_api_key": config.admin_api_key, + } + + init_tasks = [] + for i, url in enumerate(inference_urls): + init_tasks.append( + _post( + session, + f"{url}/awex/init_colocate_weight_update", + init_timeout_s, + json_data={**init_payload_base, "transfer_rank": i}, + ) + ) + for i, url in enumerate(train_urls): + init_tasks.append( + _post( + session, + f"{url}/awex/init_colocate_weight_update", + init_timeout_s, + json_data={ + **init_payload_base, + "transfer_rank": infer_world_size + i, + }, + ) + ) + await asyncio.gather(*init_tasks) + + pair_info = PairInfo( + pair_name=pair_name, + train_worker_urls=train_urls, + inference_worker_urls=inference_urls, + train_world_size=train_world_size, + inference_world_size=infer_world_size, + colocate=True, + ) + registry.register(pair_info) + + logger.info("Connected colocate pair '%s'", pair_name) + return ConnectResponse(pair_name=pair_name) + + async def _colocate_transfer_weights( + pair_info: PairInfo, + version: int, + session: aiohttp.ClientSession, + timeout_s: float, + ) -> None: + await asyncio.gather( + *[ + _post( + session, + f"{url}/awex/release_memory", + timeout_s, + json_data={"tags": ["optimizer"]}, + ) + for url in pair_info.train_worker_urls + ] + ) + + await asyncio.gather( + *[ + _post( + session, + f"{url}/awex/resume_memory", + timeout_s, + json_data={"tags": ["weights"]}, + ) + for url in pair_info.inference_worker_urls + ] + ) + + await asyncio.gather( + *[ + _post( + session, + f"{url}/awex/execute_colocate_weight_update", + timeout_s, + json_data={"version": version}, + ) + for url in pair_info.train_worker_urls + ], + *[ + _post( + session, + f"{url}/awex/execute_colocate_weight_update", + timeout_s, + json_data={"version": version}, + ) + for url in pair_info.inference_worker_urls + ], + ) + + await asyncio.gather( + *[ + _post( + session, + f"{url}/awex/release_memory", + timeout_s, + json_data={"tags": ["weights"]}, + ) + for url in pair_info.train_worker_urls + ] + ) + + await asyncio.gather( + *[ + _post( + session, + f"{url}/awex/resume_memory", + timeout_s, + json_data={"tags": ["kv_cache"]}, + ) + for url in pair_info.inference_worker_urls + ] + ) + + # Flush colocate KV keys for this version to prevent accumulation + infer_world_size = pair_info.inference_world_size + train_world_size = pair_info.train_world_size + for i in range(train_world_size): + transfer_rank = infer_world_size + i + weight_key = f"colocate_weights_rank{transfer_rank}_{version}" + done_key = f"colocate_done_rank{transfer_rank}_{version}" + kv_store.delete(pair_info.pair_name, weight_key) + kv_store.delete(pair_info.pair_name, done_key) + @asynccontextmanager async def _inference_paused( session: aiohttp.ClientSession, @@ -495,7 +708,11 @@ async def update_weights( timeout_s, pair_info.pair_name, ): - if pair_info.mode == "disk": + if pair_info.colocate: + await _colocate_transfer_weights( + pair_info, body.version, session, timeout_s + ) + elif pair_info.mode == "disk": await _disk_transfer_weights( pair_info, body.version, session, timeout_s ) diff --git a/areal/experimental/weight_update/gateway/config.py b/areal/experimental/weight_update/gateway/config.py index 4ed8b70e05..db564d8c77 100644 --- a/areal/experimental/weight_update/gateway/config.py +++ b/areal/experimental/weight_update/gateway/config.py @@ -60,6 +60,9 @@ class PairInfo: use_lora: bool = False lora_name: str = "" + # Colocated mode (training and inference share GPUs) + colocate: bool = False + def __post_init__(self): if not self.pair_name: raise ValueError("pair_name must not be empty") diff --git a/areal/experimental/weight_update/inference_adapter.py b/areal/experimental/weight_update/inference_adapter.py index 77b8319ea2..84ba633506 100644 --- a/areal/experimental/weight_update/inference_adapter.py +++ b/areal/experimental/weight_update/inference_adapter.py @@ -57,3 +57,30 @@ def batch_isend_irecv(self, **kwargs) -> None: def teardown_weight_update_group(self) -> None: """Destroy NCCL group and clear cached state.""" ... + + def init_colocate_weight_update( + self, + pair_name: str, + kv_store_url: str, + transfer_rank: int, + infer_world_size: int, + train_world_size: int, + num_engines: int, + master_port: int, + admin_api_key: str = "areal-admin-key", + timeout_s: float = 120.0, + ) -> None: + """Build device mapping, inference-only NCCL group, and colocate transport.""" + ... + + def execute_colocate_weight_update(self, version: int) -> None: + """Fetch IPC weights from KV store and apply via colocate transport.""" + ... + + def release_memory(self, tags: list[str] | None = None) -> None: + """Release GPU memory (KV cache/weights) for colocated mode.""" + ... + + def resume_memory(self, tags: list[str] | None = None) -> None: + """Resume GPU memory occupation.""" + ... diff --git a/areal/experimental/weight_update/training_adapter.py b/areal/experimental/weight_update/training_adapter.py index c2f3dcbed9..6935460416 100644 --- a/areal/experimental/weight_update/training_adapter.py +++ b/areal/experimental/weight_update/training_adapter.py @@ -57,3 +57,30 @@ def batch_isend_irecv(self, **kwargs) -> None: def teardown_weight_update_group(self) -> None: """Destroy NCCL group and clear cached state.""" ... + + def init_colocate_weight_update( + self, + pair_name: str, + kv_store_url: str, + transfer_rank: int, + infer_world_size: int, + train_world_size: int, + num_engines: int, + master_port: int, + admin_api_key: str = "areal-admin-key", + timeout_s: float = 120.0, + ) -> None: + """Register device info in KV store for colocated weight transfer.""" + ... + + def execute_colocate_weight_update(self, version: int) -> None: + """Serialize weights via IPC and put to KV store.""" + ... + + def release_memory(self, tags: list[str] | None = None) -> None: + """Release GPU memory (optimizer/weights) for colocated mode.""" + ... + + def resume_memory(self, tags: list[str] | None = None) -> None: + """Resume GPU memory occupation.""" + ... diff --git a/areal/infra/rpc/guard/app.py b/areal/infra/rpc/guard/app.py index 69a0a80f37..3e13f86dd3 100644 --- a/areal/infra/rpc/guard/app.py +++ b/areal/infra/rpc/guard/app.py @@ -590,11 +590,15 @@ def run_server( standalone guard entrypoints. Handles SIGTERM, cleanup hooks, and forked-child cleanup on shutdown. """ + import logging as _logging + from werkzeug.serving import make_server from areal.api.cli_args import NameResolveConfig from areal.utils import name_resolve, names + _logging.getLogger("werkzeug").setLevel(_logging.WARNING) + server = make_server(bind_host, port, app, threaded=True) state.server_port = server.socket.getsockname()[1] diff --git a/tests/experimental/weight_update/test_nccl_integration.py b/tests/experimental/weight_update/test_nccl_integration.py index 3213f7b75d..4b225bb203 100644 --- a/tests/experimental/weight_update/test_nccl_integration.py +++ b/tests/experimental/weight_update/test_nccl_integration.py @@ -889,3 +889,329 @@ def test_awex_megatron_dp_ep_e2e_weight_update( validate_param_names=_VALIDATE_PARAM_NAMES_MOE, init_from_scratch=True, ) + + +# --------------------------------------------------------------------------- +# Colocated weight update: Megatron + SGLang on the SAME GPUs (pure DP) +# --------------------------------------------------------------------------- + + +def _run_megatron_colocate_e2e( + *, + n_gpus: int, + pair_name: str, + tag: str, + tmp_path_factory, + model_path: str | None = None, +): + """Colocated weight transfer: MegatronEngine + SGLang share the same GPUs. + + Unlike the separated tests where inference and training each own a + disjoint half of the GPUs, colocated mode puts both on every GPU. + The LocalScheduler round-robin counter naturally wraps, giving + inference GPUs 0..N-1 and training GPUs 0..N-1 (same devices). + + Only pure DP is supported for colocated mode (TP=1, PP=1, EP=1). + """ + from areal.api import FinetuneSpec + from areal.api.cli_args import ( + InferenceEngineConfig, + OptimizerConfig, + SchedulingSpec, + TrainEngineConfig, + ) + from areal.experimental.inference_service.controller.controller import ( + RolloutControllerV2, + ) + from areal.experimental.training_service.controller.controller import ( + GatewayTrainController, + ) + from areal.experimental.weight_update.controller import ( + WeightUpdateController, + WeightUpdateControllerConfig, + ) + + tmp = tmp_path_factory.mktemp(tag) + model_path = model_path or _get_test_model_path() + scheduler = _make_local_scheduler(tmp, tag, gpu_devices=list(range(n_gpus))) + + # Both inference and training use ALL n_gpus GPUs (colocated). + inf_config = InferenceEngineConfig( + tokenizer_path=model_path, + backend=f"sglang:d{n_gpus}", + scheduling_spec=( + SchedulingSpec( + gpu=1, + cmd="python -m areal.experimental.inference_service.guard", + ), + ), + consumer_batch_size=8, + max_head_offpolicyness=1024, + setup_timeout=300.0, + admin_api_key="test-admin", + ) + inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler) + + train_config = TrainEngineConfig( + backend=f"megatron:d{n_gpus}", + experiment_name=f"test-awex-{tag}", + trial_name="t0", + path=model_path, + optimizer=OptimizerConfig(), + _version="v2", + setup_timeout=300.0, + scheduling_spec=( + SchedulingSpec( + gpu=1, + cmd="python -m areal.experimental.training_service.guard", + env_vars=dict(NCCL_CUMEM_ENABLE="0", NCCL_NVLS_ENABLE="0"), + ), + ), + ) + train_ctrl = GatewayTrainController( + train_engine="areal.engine.megatron_engine.MegatronLMEngine", + config=train_config, + scheduler=scheduler, + ) + + wu_ctrl: WeightUpdateController | None = None + try: + # -- 1. SGLang inference (uses GPUs 0..N-1) ------------------------- + inf_ctrl.initialize( + role="rollout", + server_args={"model_path": model_path, "mem_fraction_static": 0.7}, + ) + inf_worker_urls = list(inf_ctrl._inf_addrs) + + # Randomize inference weights so the transfer is NOT a no-op. + for url in inf_worker_urls: + resp = httpx.post(f"{url}/awex/debug/randomize_parameters", timeout=120.0) + assert resp.status_code == 200, f"randomize_parameters failed: {resp.text}" + + # -- 2. Megatron training (wraps to same GPUs 0..N-1) --------------- + train_ctrl.initialize( + role="actor", + ft_spec=FinetuneSpec( + total_train_epochs=1, dataset_size=100, train_batch_size=2 + ), + ) + train_worker_urls = list(train_ctrl._worker_addrs) + + # -- 3. Weight update gateway --------------------------------------- + wu_ctrl = WeightUpdateController( + config=WeightUpdateControllerConfig(host="127.0.0.1", request_timeout=300.0) + ) + wu_ctrl.initialize() + assert wu_ctrl.health_check(), "Weight update gateway health check failed" + + # -- 4. Connect with colocate=True ---------------------------------- + wu_ctrl.connect( + pair_name=pair_name, + train_worker_urls=train_worker_urls, + inference_worker_urls=inf_worker_urls, + colocate=True, + ) + + # -- 5. Colocated weight update ------------------------------------- + result = wu_ctrl.update_weights(version=1) + assert result.status == "ok" + assert result.version == 1 + wu_ctrl.disconnect() + + # -- 6. Verify inference server still works post-update ------------- + gen_resp = httpx.post( + f"{inf_worker_urls[0]}/generate", + json={ + "text": "Hello", + "sampling_params": {"max_new_tokens": 5, "temperature": 0}, + }, + timeout=30.0, + ) + assert gen_resp.status_code == 200, ( + f"Generation failed after weight update: {gen_resp.text}" + ) + + # -- 7. Validate training ↔ inference parameter equality ------------ + _validate_weight_update_correctness_megatron( + train_worker_urls=train_worker_urls, + inf_worker_url=inf_worker_urls[0], + param_dir=tmp, + tag=tag, + ) + finally: + if wu_ctrl is not None: + wu_ctrl.destroy() + train_ctrl.destroy() + inf_ctrl.destroy() + scheduler.delete_workers(None) + + +@pytest.mark.multi_gpu +@pytest.mark.slow +@pytest.mark.sglang +@pytest.mark.parametrize("n_gpus", [2, 4, 8], ids=["2gpu", "4gpu", "8gpu"]) +def test_awex_megatron_colocate_dp_e2e_weight_update(n_gpus, tmp_path_factory): + """Full round trip: colocated MegatronEngine (pure DP) + SGLang on same GPUs. + + Unlike separated tests that split GPUs between training and inference, + colocated mode shares all GPUs. Weight transfer uses CUDA IPC + (zero-copy on same device) instead of NCCL P2P across devices. + + Only pure DP (TP=1, PP=1, EP=1) is supported for colocated mode. + """ + if current_platform.device_count() < n_gpus: + pytest.skip(f"This test requires {n_gpus} GPUs") + _run_megatron_colocate_e2e( + n_gpus=n_gpus, + pair_name=f"test_megatron_colocate_dp{n_gpus}", + tag=f"megatron_colocate_dp{n_gpus}", + tmp_path_factory=tmp_path_factory, + ) + + +@pytest.mark.multi_gpu +@pytest.mark.slow +@pytest.mark.sglang +def test_awex_megatron_colocate_dp_multi_version_e2e(tmp_path_factory): + """Colocated weight update with multiple sequential versions. + + Verifies that the colocated IPC path correctly handles version + sequencing: version 1 → version 2. The KV store keys include + the version number, so each round must produce fresh IPC handles. + """ + n_gpus = 2 + if current_platform.device_count() < n_gpus: + pytest.skip(f"This test requires {n_gpus} GPUs") + + from areal.api import FinetuneSpec + from areal.api.cli_args import ( + InferenceEngineConfig, + OptimizerConfig, + SchedulingSpec, + TrainEngineConfig, + ) + from areal.experimental.inference_service.controller.controller import ( + RolloutControllerV2, + ) + from areal.experimental.training_service.controller.controller import ( + GatewayTrainController, + ) + from areal.experimental.weight_update.controller import ( + WeightUpdateController, + WeightUpdateControllerConfig, + ) + + tag = "megatron_colocate_multi_ver" + tmp = tmp_path_factory.mktemp(tag) + model_path = _get_test_model_path() + scheduler = _make_local_scheduler(tmp, tag, gpu_devices=list(range(n_gpus))) + + inf_config = InferenceEngineConfig( + tokenizer_path=model_path, + backend=f"sglang:d{n_gpus}", + scheduling_spec=( + SchedulingSpec( + gpu=1, + cmd="python -m areal.experimental.inference_service.guard", + ), + ), + consumer_batch_size=8, + max_head_offpolicyness=1024, + setup_timeout=300.0, + admin_api_key="test-admin", + ) + inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler) + + train_config = TrainEngineConfig( + backend=f"megatron:d{n_gpus}", + experiment_name=f"test-awex-{tag}", + trial_name="t0", + path=model_path, + optimizer=OptimizerConfig(), + _version="v2", + setup_timeout=300.0, + scheduling_spec=( + SchedulingSpec( + gpu=1, + cmd="python -m areal.experimental.training_service.guard", + env_vars=dict(NCCL_CUMEM_ENABLE="0", NCCL_NVLS_ENABLE="0"), + ), + ), + ) + train_ctrl = GatewayTrainController( + train_engine="areal.engine.megatron_engine.MegatronLMEngine", + config=train_config, + scheduler=scheduler, + ) + + wu_ctrl: WeightUpdateController | None = None + try: + inf_ctrl.initialize( + role="rollout", + server_args={"model_path": model_path, "mem_fraction_static": 0.7}, + ) + inf_worker_urls = list(inf_ctrl._inf_addrs) + + for url in inf_worker_urls: + resp = httpx.post(f"{url}/awex/debug/randomize_parameters", timeout=120.0) + assert resp.status_code == 200, f"randomize_parameters failed: {resp.text}" + + train_ctrl.initialize( + role="actor", + ft_spec=FinetuneSpec( + total_train_epochs=1, dataset_size=100, train_batch_size=2 + ), + ) + train_worker_urls = list(train_ctrl._worker_addrs) + + wu_ctrl = WeightUpdateController( + config=WeightUpdateControllerConfig(host="127.0.0.1", request_timeout=300.0) + ) + wu_ctrl.initialize() + assert wu_ctrl.health_check() + + wu_ctrl.connect( + pair_name="test_colocate_multi_ver", + train_worker_urls=train_worker_urls, + inference_worker_urls=inf_worker_urls, + colocate=True, + ) + + # Version 1 + result1 = wu_ctrl.update_weights(version=1) + assert result1.status == "ok" + assert result1.version == 1 + + # Version 2 + result2 = wu_ctrl.update_weights(version=2) + assert result2.status == "ok" + assert result2.version == 2 + + wu_ctrl.disconnect() + + # Verify inference still works after two sequential updates + gen_resp = httpx.post( + f"{inf_worker_urls[0]}/generate", + json={ + "text": "Hello", + "sampling_params": {"max_new_tokens": 5, "temperature": 0}, + }, + timeout=30.0, + ) + assert gen_resp.status_code == 200, ( + f"Generation failed after weight updates: {gen_resp.text}" + ) + + # Final parameter equality check + _validate_weight_update_correctness_megatron( + train_worker_urls=train_worker_urls, + inf_worker_url=inf_worker_urls[0], + param_dir=tmp, + tag=tag, + ) + finally: + if wu_ctrl is not None: + wu_ctrl.destroy() + train_ctrl.destroy() + inf_ctrl.destroy() + scheduler.delete_workers(None)