From 4606a459cde1614066d2362649d5f2733839a98b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=9A=E6=83=9F?= Date: Mon, 11 May 2026 15:16:19 +0800 Subject: [PATCH 1/2] feat: enable v2 training pipeline with controller parity Bring GatewayTrainController and RolloutControllerV2 to full parity with v1 controllers for RL training paths. Key changes: - Route to RolloutControllerV2 when config._version=="v2" - Add version management, connect_engine, clear_batches to GatewayTrainController - Simplify AsyncRewardWrapper lifecycle with atexit shutdown - Unify HTTP client sessions across inference/training controllers - Switch default workflow to MathAgent in example configs - Add agent config section to all example YAML files - Remove obsolete get_custom_reward_fn from reward module - Add async reward wrapper tests --- areal/api/reward_api.py | 164 ++++------ areal/engine/sglang_remote.py | 10 +- areal/engine/vllm_remote.py | 10 +- .../controller/controller.py | 300 ++++++------------ .../inference_service/controller/workflow.py | 31 +- .../inference_service/data_proxy/app.py | 34 +- .../inference_service/data_proxy/session.py | 2 +- .../inference_service/inf_bridge.py | 4 +- .../training_service/controller/controller.py | 207 +++++++++++- .../weight_update/awex/fsdp_adapter.py | 8 + .../weight_update/controller/controller.py | 23 +- .../experimental/weight_update/gateway/app.py | 57 +--- areal/infra/controller/rollout_controller.py | 1 + areal/infra/workflow_context.py | 16 +- areal/reward/__init__.py | 20 -- areal/trainer/rl_trainer.py | 31 +- areal/workflow/openai/math_agent.py | 19 +- examples/math/boba_grpo.yaml | 5 + examples/math/gsm8k_dapo_dynamic_bs.yaml | 5 + examples/math/gsm8k_drgrpo.yaml | 5 + examples/math/gsm8k_grpo.yaml | 7 +- examples/math/gsm8k_grpo_cpu.yaml | 5 + examples/math/gsm8k_grpo_lora.yaml | 5 + examples/math/gsm8k_grpo_megatron.yaml | 5 + examples/math/gsm8k_grpo_megatron_fp8.yaml | 5 + examples/math/gsm8k_grpo_megatron_lora.yaml | 5 + 
.../math/gsm8k_grpo_megatron_lora_moe.yaml | 5 + examples/math/gsm8k_grpo_npu.yaml | 5 + examples/math/gsm8k_gspo.yaml | 5 + examples/math/gsm8k_liteppo.yaml | 5 + examples/math/gsm8k_m2po.yaml | 5 + examples/math/gsm8k_ppo.yaml | 5 + examples/math/gsm8k_ppo_megatron.yaml | 5 + examples/math/gsm8k_reinforce.yaml | 5 + examples/math/gsm8k_reinforce_baseline.yaml | 5 + examples/math/gsm8k_rl.py | 14 +- examples/math/gsm8k_rloo.yaml | 5 + examples/math/gsm8k_sapo.yaml | 5 + examples/openclaw/config.yaml | 1 + pyproject.toml | 1 + pyproject.vllm.toml | 1 + .../inference_service/test_controller.py | 37 --- .../test_controller_version.py | 17 +- tests/test_async_reward_wrapper.py | 165 ++++++++++ tests/test_examples.py | 7 +- uv.lock | 1 + uv.vllm.lock | 1 + 47 files changed, 776 insertions(+), 508 deletions(-) create mode 100644 tests/test_async_reward_wrapper.py diff --git a/areal/api/reward_api.py b/areal/api/reward_api.py index 3cf32984ea..44bef1f692 100644 --- a/areal/api/reward_api.py +++ b/areal/api/reward_api.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import atexit import os import threading -import traceback -import weakref from collections.abc import Callable from concurrent.futures import ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool @@ -61,15 +60,18 @@ def reward_fn( class AsyncRewardWrapper: - """ - Wraps a synchronous reward function to make it async with timeout handling. - Automatically manages ProcessPoolExecutor lifecycle based on instance count. + """Wraps a synchronous reward function for async execution with timeout and retries. + + Executors are shared by ``max_workers`` key and cleaned up via ``atexit``. Includes automatic recovery from broken process pools. + + The reward function and its arguments must be picklable since they + are dispatched to worker processes via ``ProcessPoolExecutor``. 
""" - _executors = {} - _instance_counts = {} + _executors: dict[int, ProcessPoolExecutor] = {} _lock = threading.Lock() + _atexit_registered = False def __init__( self, @@ -83,8 +85,6 @@ def __init__( if max_workers is None: cpu_count = os.cpu_count() or 1 device_count = _get_device_count_safely() - # Heuristic for max_workers: distribute CPU cores across devices, - # then halve to be conservative, ensuring at least one worker. max_workers = max((cpu_count // device_count) // 2, 1) self.max_workers = max_workers self.max_retries = max_retries @@ -95,115 +95,89 @@ def __init__( self._executors[self._executor_key] = ProcessPoolExecutor( max_workers=max_workers ) - self._instance_counts[self._executor_key] = 0 - self._instance_counts[self._executor_key] += 1 - - weakref.finalize(self, AsyncRewardWrapper._cleanup_executor, max_workers) + if not AsyncRewardWrapper._atexit_registered: + atexit.register(AsyncRewardWrapper._atexit_shutdown_all) + AsyncRewardWrapper._atexit_registered = True @classmethod - def _cleanup_executor(cls, executor_key): - """Called when an AsyncRewardWrapper instance is garbage collected""" + def _atexit_shutdown_all(cls): + """Shut down all executors before ``_python_exit`` to prevent + worker processes deadlocking in ``_finalize_join``. + Must use ``wait=True`` so the result queue is fully drained. 
+ """ with cls._lock: - if executor_key in cls._instance_counts: - cls._instance_counts[executor_key] -= 1 - if cls._instance_counts[executor_key] <= 0: - if executor_key in cls._executors: - executor = cls._executors.pop(executor_key) - executor.shutdown(wait=False, cancel_futures=True) - logger.debug( - f"ProcessPoolExecutor with {executor_key} workers shut down" - ) - cls._instance_counts.pop(executor_key, None) - - @classmethod - def _recreate_executor(cls, executor_key, max_workers): - """Recreate a broken ProcessPoolExecutor""" - with cls._lock: - if executor_key in cls._executors: - # Clean up the broken executor - old_executor = cls._executors[executor_key] + for executor in cls._executors.values(): try: - old_executor.shutdown(wait=False) + executor.shutdown(wait=True, cancel_futures=True) except Exception as e: - logger.warning(f"Error shutting down broken executor: {e}") + logger.warning(f"Error shutting down executor at exit: {e}") + cls._executors.clear() - # Create a new executor - cls._executors[executor_key] = ProcessPoolExecutor( - max_workers=max_workers - ) - logger.info(f"Recreated ProcessPoolExecutor with {max_workers} workers") - return cls._executors[executor_key] - return None + @classmethod + def _recreate_executor( + cls, + executor_key: int, + max_workers: int, + broken: ProcessPoolExecutor, + ) -> ProcessPoolExecutor | None: + with cls._lock: + current = cls._executors.get(executor_key) + if current is not broken: + return current + try: + broken.shutdown(wait=False) + except Exception as e: + logger.warning(f"Error shutting down broken executor: {e}") + try: + new_executor = ProcessPoolExecutor(max_workers=max_workers) + except Exception: + logger.exception("Failed to create replacement ProcessPoolExecutor") + cls._executors.pop(executor_key, None) + return None + cls._executors[executor_key] = new_executor + logger.info(f"Recreated ProcessPoolExecutor with {max_workers} workers") + return new_executor async def __call__(self, *args, 
**kwargs) -> float: - last_exception = None - for attempt in range(self.max_retries + 1): - with self._lock: - executor = self._executors.get(self._executor_key) - + executor = self._executors.get(self._executor_key) if executor is None: raise RuntimeError("ProcessPoolExecutor has been shut down") - loop = asyncio.get_event_loop() + is_last = attempt == self.max_retries try: - future = loop.run_in_executor( + future = asyncio.get_running_loop().run_in_executor( executor, partial(self.reward_fn, *args, **kwargs), ) - reward = await asyncio.wait_for( - future, - timeout=self.timeout_seconds, - ) - return reward + return await asyncio.wait_for(future, timeout=self.timeout_seconds) except TimeoutError: - last_exception = TimeoutError( - f"Reward computation timed out after {self.timeout_seconds}s" - ) logger.warning( f"Computing reward timeout after {self.timeout_seconds}s " f"(attempt {attempt + 1}/{self.max_retries + 1}). " - f"{'Retrying...' if attempt < self.max_retries else 'Returning 0.'}" + f"{'Returning 0.' if is_last else 'Retrying...'}" ) - if attempt < self.max_retries: - continue - return 0 - except BrokenProcessPool as e: - last_exception = e + if is_last: + return 0 + except BrokenProcessPool: logger.warning( f"ProcessPoolExecutor broken (attempt {attempt + 1}/{self.max_retries + 1}). " "Attempting to recreate..." ) - if attempt < self.max_retries: - # Try to recreate the executor - new_executor = self._recreate_executor( - self._executor_key, self.max_workers - ) - if new_executor is None: - logger.error("Failed to recreate ProcessPoolExecutor") - break - # Continue to next attempt - continue - else: - logger.error("Max retries exceeded for BrokenProcessPool.") - traceback.print_exc() - raise e - except Exception as e: - last_exception = e - logger.error(f"Unexpected error in reward computation: {e}") - if attempt < self.max_retries: - logger.info( - f"Retrying... 
(attempt {attempt + 1}/{self.max_retries + 1})" + if is_last: + raise + if ( + self._recreate_executor( + self._executor_key, self.max_workers, executor ) - continue - else: - logger.error("Max retries exceeded for unexpected error.") - traceback.print_exc() - raise e - - # If we get here, all retries failed - if last_exception: - traceback.print_exc() - raise last_exception - else: - raise RuntimeError("Reward computation failed after all retries.") + is None + ): + raise + except Exception: + logger.exception( + f"Reward computation error (attempt {attempt + 1}/{self.max_retries + 1})" + ) + if is_last: + raise + + return 0 diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 52039df326..566eef3157 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -530,9 +530,13 @@ def export_stats(self) -> dict[str, float]: return stats_tracker.export_all(reduce_group=None) @classmethod - def as_controller( - cls, config: InferenceEngineConfig, scheduler: Scheduler - ) -> RolloutController: + def as_controller(cls, config: InferenceEngineConfig, scheduler: Scheduler): + if config._version == "v2": + from areal.experimental.inference_service.controller.controller import ( + RolloutControllerV2, + ) + + return RolloutControllerV2(config=config, scheduler=scheduler) return RolloutController(cls, config=config, scheduler=scheduler) def clear_batches(self, shard_ids: list[str]) -> None: diff --git a/areal/engine/vllm_remote.py b/areal/engine/vllm_remote.py index 5db6cc8f88..6a89446e39 100644 --- a/areal/engine/vllm_remote.py +++ b/areal/engine/vllm_remote.py @@ -493,9 +493,13 @@ def export_stats(self) -> dict[str, float]: return stats_tracker.export_all(reduce_group=None) @classmethod - def as_controller( - cls, config: InferenceEngineConfig, scheduler: Scheduler - ) -> RolloutController: + def as_controller(cls, config: InferenceEngineConfig, scheduler: Scheduler): + if config._version == "v2": + from 
areal.experimental.inference_service.controller.controller import ( + RolloutControllerV2, + ) + + return RolloutControllerV2(config=config, scheduler=scheduler) return RolloutController(cls, config=config, scheduler=scheduler) def clear_batches(self, shard_ids: list[str]) -> None: diff --git a/areal/experimental/inference_service/controller/controller.py b/areal/experimental/inference_service/controller/controller.py index cd2525f512..99882de956 100644 --- a/areal/experimental/inference_service/controller/controller.py +++ b/areal/experimental/inference_service/controller/controller.py @@ -12,7 +12,6 @@ import asyncio import concurrent.futures -import copy import os import sys import threading @@ -28,12 +27,13 @@ import httpx from openai.types.chat import ChatCompletion, ChatCompletionChunk +from areal.infra.utils.http import async_http_retry, create_httpx_client + if TYPE_CHECKING: from areal.api.scheduler_api import Scheduler, Worker from areal.api.cli_args import InferenceEngineConfig from areal.api.io_struct import LocalInfServerInfo -from areal.infra.utils.http import async_http_retry, create_httpx_client from areal.utils import logging from areal.utils.network import format_hostport @@ -126,7 +126,7 @@ def __init__( # Worker management self.workers: list[Worker] = [] - self.server_infos: list[LocalInfServerInfo] = [] + self._server_infos: list[LocalInfServerInfo] = [] self._worker_role: str = "" # Addresses resolved after initialization @@ -257,18 +257,8 @@ def _bg_initialize( return from areal.infra.remote_inf_engine import RemoteInfEngine - from areal.infra.workflow_executor import WorkflowExecutor - - self._workflow_executor = WorkflowExecutor( - config=cast(InferenceEngineConfig, self.config), - inference_engine=cast(RemoteInfEngine, self), - ) - self._workflow_executor.initialize() - - if self._shutdown_requested.is_set(): - return - from areal.infra.staleness_manager import StalenessManager + from areal.infra.workflow_executor import WorkflowExecutor 
max_concurrent = ( self.config.max_concurrent_rollouts or self.config.consumer_batch_size @@ -280,6 +270,16 @@ def _bg_initialize( max_staleness=self.config.max_head_offpolicyness, ) + self._workflow_executor = WorkflowExecutor( + config=self.config, + inference_engine=cast(RemoteInfEngine, self), + staleness_manager=self._staleness_manager, + ) + self._workflow_executor.initialize() + + if self._shutdown_requested.is_set(): + return + logger.info("RolloutControllerV2 initialized (role=%s)", self._worker_role) if self.config.model: @@ -417,7 +417,7 @@ async def _async_initialize( if self.external_mode: logger.info("External mode — skipping inference server launch") elif server_infos is not None: - self.server_infos = server_infos + self._server_infos = server_infos self._inf_addrs = [ f"http://{format_hostport(info.host, info.port)}" for info in server_infos @@ -554,82 +554,14 @@ async def _async_fork_inf_servers( nnodes_per_instance: int, server_args: dict[str, Any] | None, ) -> None: - """Fork inference server groups in parallel across all DP ranks.""" - tp_size = alloc.parallel.tp_size - pp_size = alloc.parallel.pp_size - - # Build backend-specific launch command builder if inf_backend == "sglang": from areal.api.cli_args import SGLangConfig - sglang_config = SGLangConfig() - if server_args: - sglang_config = copy.deepcopy(sglang_config) - for k, v in server_args.items(): - if hasattr(sglang_config, k): - setattr(sglang_config, k, v) - else: - logger.warning( - "SGLangConfig has no attribute %r, ignoring " - "server_args entry (value=%r)", - k, - v, - ) - - def _build_launch_cmd( - host: str | None, - port: int | None, - n_nodes: int = 1, - node_rank: int = 0, - dist_init_addr: str | None = None, - ) -> list[str]: - return SGLangConfig.build_cmd( - sglang_config=sglang_config, - tp_size=tp_size, - pp_size=pp_size, - base_gpu_id=0, - host=host, - port=port, - dist_init_addr=dist_init_addr, - n_nodes=n_nodes, - node_rank=node_rank, - ) - + _build_launch_cmd = 
SGLangConfig.build_cmd_from_args elif inf_backend == "vllm": from areal.api.cli_args import vLLMConfig - vllm_config = vLLMConfig() - if server_args: - vllm_config = copy.deepcopy(vllm_config) - for k, v in server_args.items(): - if hasattr(vllm_config, k): - setattr(vllm_config, k, v) - else: - logger.warning( - "vLLMConfig has no attribute %r, ignoring " - "server_args entry (value=%r)", - k, - v, - ) - - def _build_launch_cmd( - host: str | None, - port: int | None, - n_nodes: int = 1, - node_rank: int = 0, - dist_init_addr: str | None = None, - ) -> list[str]: - return vLLMConfig.build_cmd( - vllm_config=vllm_config, - tp_size=tp_size, - pp_size=pp_size, - host=host, - port=port, - dist_init_addr=dist_init_addr, - n_nodes=n_nodes, - node_rank=node_rank, - ) - + _build_launch_cmd = vLLMConfig.build_cmd_from_args else: raise ValueError(f"Unsupported inference backend: {inf_backend!r}") @@ -671,13 +603,14 @@ async def _fork_node(node_rank: int, worker: Any) -> tuple[str, int, str]: inf_host: str = port_data["host"] inf_port: int = port_data["ports"][0] - cmd = _build_launch_cmd( + server_args.update( host=inf_host, port=inf_port, - n_nodes=nnodes_per_instance, + nnodes=nnodes_per_instance, node_rank=node_rank, dist_init_addr=dist_init_addr, ) + cmd = _build_launch_cmd(server_args) fork_payload: dict[str, Any] = { "role": "inf-server", @@ -728,7 +661,7 @@ async def _fork_node(node_rank: int, worker: Any) -> tuple[str, int, str]: for host, port, forked in group_results: addr = f"http://{format_hostport(host, port)}" self._inf_addrs.append(addr) - self.server_infos.append( + self._server_infos.append( LocalInfServerInfo( host=host, port=port, @@ -1063,7 +996,7 @@ def destroy(self) -> None: self._service_roles.clear() self.workers.clear() - self.server_infos.clear() + self._server_infos.clear() with self._online_waiters_lock: for waiter in self._online_waiters: if not waiter.future.done(): @@ -1097,8 +1030,8 @@ async def _async_set_version(self, version: int) -> None: 
payload = {"version": version} results = await asyncio.gather( *[ - self._async_gateway_http_post(f"/set_version/{wid}", payload) - for wid in self._worker_ids.values() + self._async_data_proxy_post(addr, "/set_version", payload) + for addr in self._data_proxy_addrs ], return_exceptions=True, ) @@ -1134,6 +1067,7 @@ def submit( is_eval: bool = False, group_size: int = 1, ) -> int: + self._ensure_initialized() resolved_workflow = self._resolve_workflow( workflow, workflow_kwargs, @@ -1154,6 +1088,7 @@ def wait( timeout: float | None = None, raise_timeout: bool = True, ) -> list[dict[str, Any] | None]: + self._ensure_initialized() return self.workflow_executor.wait( count, timeout=timeout, raise_timeout=raise_timeout ) @@ -1416,21 +1351,15 @@ async def _stream_chat_completion( def pause(self) -> None: """Pause dispatcher + pause all workers.""" - from areal.infra.utils.concurrent import run_async_task - self._ensure_initialized() - if self._workflow_executor is not None: - self._workflow_executor.pause() - run_async_task(self.pause_generation) + assert self._workflow_executor is not None + self._workflow_executor.pause() def resume(self) -> None: """Resume all workers + resume dispatcher.""" - from areal.infra.utils.concurrent import run_async_task - self._ensure_initialized() - run_async_task(self.continue_generation) - if self._workflow_executor is not None: - self._workflow_executor.resume() + assert self._workflow_executor is not None + self._workflow_executor.resume() def offload(self) -> None: """Offload model memory on all inference workers.""" @@ -1440,12 +1369,12 @@ def offload(self) -> None: run_async_task(self._async_offload) async def _async_offload(self) -> None: - if not self._gateway_addr: + if not self._data_proxy_addrs: return results = await asyncio.gather( *( - self._async_gateway_http_post(f"/release_memory_occupation/{wid}", {}) - for wid in self._worker_ids.values() + self._async_data_proxy_post(addr, "/release_memory_occupation", {}) + for addr 
in self._data_proxy_addrs ), return_exceptions=True, ) @@ -1463,15 +1392,13 @@ def onload(self, tags: list[str] | None = None) -> None: run_async_task(self._async_onload, tags) async def _async_onload(self, tags: list[str] | None = None) -> None: - if not self._gateway_addr: + if not self._data_proxy_addrs: return payload: dict = {"tags": tags} if tags is not None else {} results = await asyncio.gather( *( - self._async_gateway_http_post( - f"/resume_memory_occupation/{wid}", payload - ) - for wid in self._worker_ids.values() + self._async_data_proxy_post(addr, "/resume_memory_occupation", payload) + for addr in self._data_proxy_addrs ), return_exceptions=True, ) @@ -1481,49 +1408,53 @@ async def _async_onload(self, tags: list[str] | None = None) -> None: if failed and len(failed) == len(results): raise RuntimeError(f"onload failed on ALL {len(failed)} workers") - async def pause_generation(self, worker_id: str | None = None) -> None: - """Pause generation on a specific worker, or all workers if worker_id is None.""" - if not self._gateway_addr: + def pause_generation(self) -> None: + """Pause generation on all workers.""" + from areal.infra.utils.concurrent import run_async_task + + self._ensure_initialized() + run_async_task(self._async_pause_generation) + + async def _async_pause_generation(self) -> None: + if not self._data_proxy_addrs: return - if worker_id is not None: - await self._async_gateway_http_post(f"/pause_generation/{worker_id}", {}) - else: - results = await asyncio.gather( - *[ - self._async_gateway_http_post(f"/pause_generation/{wid}", {}) - for wid in self._worker_ids.values() - ], - return_exceptions=True, - ) - failed = [r for r in results if isinstance(r, Exception)] - for r in failed: - logger.error("Failed to pause generation on a worker: %s", r) - if failed and len(failed) == len(results): - raise RuntimeError( - f"pause_generation failed on ALL {len(failed)} workers" - ) + results = await asyncio.gather( + *[ + 
self._async_data_proxy_post(addr, "/pause_generation", {}) + for addr in self._data_proxy_addrs + ], + return_exceptions=True, + ) + failed = [r for r in results if isinstance(r, Exception)] + for r in failed: + logger.error("Failed to pause generation on a worker: %s", r) + if failed and len(failed) == len(results): + raise RuntimeError(f"pause_generation failed on ALL {len(failed)} workers") - async def continue_generation(self, worker_id: str | None = None) -> None: - """Continue generation on a specific worker, or all workers if worker_id is None.""" - if not self._gateway_addr: + def continue_generation(self) -> None: + """Continue generation on all workers.""" + from areal.infra.utils.concurrent import run_async_task + + self._ensure_initialized() + run_async_task(self._async_continue_generation) + + async def _async_continue_generation(self) -> None: + if not self._data_proxy_addrs: return - if worker_id is not None: - await self._async_gateway_http_post(f"/continue_generation/{worker_id}", {}) - else: - results = await asyncio.gather( - *[ - self._async_gateway_http_post(f"/continue_generation/{wid}", {}) - for wid in self._worker_ids.values() - ], - return_exceptions=True, + results = await asyncio.gather( + *[ + self._async_data_proxy_post(addr, "/continue_generation", {}) + for addr in self._data_proxy_addrs + ], + return_exceptions=True, + ) + failed = [r for r in results if isinstance(r, Exception)] + for r in failed: + logger.error("Failed to continue generation on a worker: %s", r) + if failed and len(failed) == len(results): + raise RuntimeError( + f"continue_generation failed on ALL {len(failed)} workers" ) - failed = [r for r in results if isinstance(r, Exception)] - for r in failed: - logger.error("Failed to continue generation on a worker: %s", r) - if failed and len(failed) == len(results): - raise RuntimeError( - f"continue_generation failed on ALL {len(failed)} workers" - ) # -- Stats 
------------------------------------------------------------- @@ -1539,12 +1470,6 @@ def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None # -- Proxy compatibility (gateway IS the proxy) ------------------------ - def start_proxy(self) -> None: - """No-op — gateway already acts as the proxy.""" - - def start_proxy_gateway(self) -> None: - """No-op — gateway already acts as the proxy gateway.""" - @property def proxy_gateway_addr(self) -> str: self._ensure_initialized() @@ -1552,6 +1477,16 @@ def proxy_gateway_addr(self) -> str: # -- Properties -------------------------------------------------------- + @property + def inference_worker_urls(self) -> list[str]: + self._ensure_initialized() + return list(self._inf_addrs) + + @property + def server_infos(self) -> list[LocalInfServerInfo]: + self._ensure_initialized() + return self._server_infos + @property def worker_ids(self) -> dict[str, str]: """Return mapping from data proxy address to router-assigned worker_id.""" @@ -1570,14 +1505,6 @@ def workflow_executor(self): raise RuntimeError("RolloutControllerV2.initialize() must be called first") return self._workflow_executor - @property - def dispatcher(self): - return self.workflow_executor.dispatcher - - @property - def runner(self): - return self.dispatcher.runner - # -- Workflow resolution helpers ---------------------------------------- def _wrap_agent(self, agent: Any, group_size: int = 1): @@ -1834,30 +1761,6 @@ def _kill_forked_service( "Error killing forked service %s/%d: %s", role, worker_index, exc ) - def _gateway_http_post(self, endpoint: str, payload: dict[str, Any]) -> None: - """Make a synchronous HTTP POST to the gateway with admin auth. - - Use ``_async_gateway_http_post`` from async contexts to avoid blocking - the event loop. - - Raises ``RuntimeError`` on HTTP errors or connection failures so that - callers (e.g. ``pause()`` / ``resume()``) can detect and handle them. 
- """ - url = f"{self._gateway_addr}{endpoint}" - try: - resp = self._sync_client.post( - url, - json=payload, - headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, - timeout=self.config.request_timeout, - ) - if resp.status_code >= 400: - raise RuntimeError( - f"Gateway {endpoint} returned {resp.status_code}: {resp.text}" - ) - except httpx.HTTPError as exc: - raise RuntimeError(f"Failed to POST {endpoint}: {exc}") from exc - async def _get_async_client(self) -> httpx.AsyncClient: """Return the shared async HTTP client, recreating it when the event loop changes. @@ -1881,31 +1784,6 @@ async def _get_async_client(self) -> httpx.AsyncClient: pass return self._async_client - @async_http_retry - async def _async_gateway_http_post( - self, endpoint: str, payload: dict[str, Any] - ) -> None: - """Make a non-blocking HTTP POST to the gateway with admin auth. - - Raises ``RuntimeError`` on HTTP errors or connection failures so that - callers (e.g. ``pause_generation()`` / ``continue_generation()``) can - detect and handle them. 
- """ - url = f"{self._gateway_addr}{endpoint}" - try: - client = await self._get_async_client() - resp = await client.post( - url, - json=payload, - headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, - ) - if resp.status_code >= 400: - raise RuntimeError( - f"Gateway {endpoint} returned {resp.status_code}: {resp.text}" - ) - except httpx.HTTPError as exc: - raise RuntimeError(f"Failed to POST {endpoint}: {exc}") from exc - @async_http_retry async def _async_data_proxy_post( self, addr: str, endpoint: str, payload: dict[str, Any] diff --git a/areal/experimental/inference_service/controller/workflow.py b/areal/experimental/inference_service/controller/workflow.py index bdf1304ef4..4ab4345b26 100644 --- a/areal/experimental/inference_service/controller/workflow.py +++ b/areal/experimental/inference_service/controller/workflow.py @@ -10,8 +10,8 @@ import openai from areal.api.workflow_api import RolloutWorkflow -from areal.experimental.openai.proxy.server import deserialize_interactions from areal.infra import workflow_context +from areal.infra.rpc.serialization import deserialize_value from areal.infra.utils.http import async_http_retry from areal.utils import logging, stats_tracker @@ -105,7 +105,7 @@ async def _export_interactions( session_ids: list[str], group_id: str | None = None, trajectory_id: int | None = None, - ) -> dict[str, InteractionWithTokenLogpReward]: + ) -> dict[str, Any]: url = f"{self.gateway_addr}/{_EXPORT_TRAJECTORIES_PATHNAME}" headers = {"Authorization": f"Bearer {self._admin_api_key}"} payload: dict[str, Any] = { @@ -120,7 +120,7 @@ async def _export_interactions( resp.raise_for_status() data = await resp.json() - return deserialize_interactions(data["interactions"]) + return deserialize_value(data["traj"]) async def arun_episode( self, @@ -198,23 +198,19 @@ async def _run_one(session_id: str, session_api_key: str) -> float: ) session_ids = [sid for sid, _ in sessions] - interactions = await self._export_interactions( + traj = 
await self._export_interactions( http_session, session_ids, group_id=group_id, ) - if not interactions: - logger.warning( - "Group %s has no interactions, all trajectories rejected.", - group_id, - ) + if not traj: return None tracker = stats_tracker.get(workflow_context.stat_scope()) for r in rewards: tracker.scalar(reward=r) - return interactions + return traj async def _run_online( self, @@ -227,15 +223,20 @@ async def _run_online( if not export_request: return None - interactions = await self._export_interactions( + traj = await self._export_interactions( http_session, [export_request["session_id"]], trajectory_id=export_request["trajectory_id"], ) - if not interactions: + if not traj: + return None + + if "rewards" not in traj or not traj["rewards"]: + logger.warning( + "Exported trajectory is missing rewards. This trajectory will be rejected." + ) return None - last_id = next(reversed(interactions)) - last_reward = interactions[last_id].reward + last_reward = float(traj["rewards"][-1]) stats_tracker.get(workflow_context.stat_scope()).scalar(reward=last_reward) - return interactions + return traj diff --git a/areal/experimental/inference_service/data_proxy/app.py b/areal/experimental/inference_service/data_proxy/app.py index 43fc70e2bf..fd2794983c 100644 --- a/areal/experimental/inference_service/data_proxy/app.py +++ b/areal/experimental/inference_service/data_proxy/app.py @@ -37,12 +37,18 @@ from areal.experimental.inference_service.sglang.bridge import SGLangBridgeBackend from areal.experimental.inference_service.vllm.bridge import VLLMBridgeBackend from areal.experimental.openai.client import ArealOpenAI -from areal.experimental.openai.proxy.server import serialize_interactions +from areal.experimental.openai.types import ( + InteractionWithTokenLogpReward, + concat_string_interactions, +) from areal.infra.rpc.guard.data_blueprint import ( data_bp, ) +from areal.infra.rpc.rtensor import RTensor +from areal.infra.rpc.serialization import serialize_value 
from areal.infra.utils.http import create_httpx_client from areal.utils import logging +from areal.utils.data import concat_padded_tensors logger = logging.getLogger("InferenceDataProxy") @@ -438,6 +444,11 @@ async def set_version(request: Request): if version is None or not isinstance(version, int): raise HTTPException(status_code=400, detail="'version' (int) is required") app.state.version = version + + # Propagate version to InfBridge so it stamps correct versions on generated tokens + if app.state.inf_bridge is not None: + app.state.inf_bridge.set_version(version) + return SetVersionResponse(status="ok", version=version) @app.get("/get_version", response_model=GetVersionResponse) @@ -712,9 +723,6 @@ async def export_trajectories( detail="session_ids must be a non-empty list", ) - from areal.experimental.openai.types import InteractionWithTokenLogpReward - from areal.infra.rpc.rtensor import RTensor - merged: dict[str, InteractionWithTokenLogpReward] = {} for sid in body.session_ids: @@ -728,24 +736,22 @@ async def export_trajectories( style=body.style, trajectory_id=body.trajectory_id, ) + merged.update(interactions) except KeyError: continue - for item in interactions.values(): - if item.has_tensor_data: - item.to_tensor_dict() - item._cache = RTensor.remotize( - item._cache, node_addr=config.serving_addr - ) - - merged.update(interactions) + if merged and all(v.has_tensor_data for v in merged.values()): + traj = concat_padded_tensors([v.to_tensor_dict() for v in merged.values()]) + traj = RTensor.remotize(traj, node_addr=config.serving_addr) + else: + traj = concat_string_interactions(merged) if body.remove_session: for sid in body.session_ids: store.remove_session(sid) - serialized = serialize_interactions(merged) - return ExportTrajectoriesResponse(interactions=serialized) + serialized = serialize_value(traj) + return ExportTrajectoriesResponse(traj=serialized) # ========================================================================= # Runtime backend 
reconfiguration (for fork-based deployment) diff --git a/areal/experimental/inference_service/data_proxy/session.py b/areal/experimental/inference_service/data_proxy/session.py index d36ae9ca81..e6cbe6c2d2 100644 --- a/areal/experimental/inference_service/data_proxy/session.py +++ b/areal/experimental/inference_service/data_proxy/session.py @@ -80,7 +80,7 @@ class ExportTrajectoriesRequest(BaseModel): class ExportTrajectoriesResponse(BaseModel): """Response containing merged serialized interactions.""" - interactions: Any + traj: dict[str, Any] @dataclass(frozen=True) diff --git a/areal/experimental/inference_service/inf_bridge.py b/areal/experimental/inference_service/inf_bridge.py index 5de42d8bbf..f7532941f2 100644 --- a/areal/experimental/inference_service/inf_bridge.py +++ b/areal/experimental/inference_service/inf_bridge.py @@ -191,6 +191,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: accumulated_tokens: list[int] = [] accumulated_logprobs: list[float] = [] + accumulated_versions: list[int] = [] stop_reason: _StopReason | None = None final_routed_experts: np.ndarray | None = None @@ -220,6 +221,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: accumulated_tokens.extend(result.output_tokens) accumulated_logprobs.extend(result.output_logprobs) + accumulated_versions.extend([self._version] * len(result.output_tokens)) stop_reason = cast(_StopReason, result.stop_reason) if result.routed_experts is not None: @@ -254,7 +256,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: input_tokens=list(req.input_ids), output_tokens=accumulated_tokens, output_logprobs=accumulated_logprobs, - output_versions=[self._version] * len(accumulated_tokens), + output_versions=accumulated_versions, stop_reason=stop_reason, tokenizer=req.tokenizer, latency=latency, diff --git a/areal/experimental/training_service/controller/controller.py b/areal/experimental/training_service/controller/controller.py index 8a498630f6..b6adc58825 
100644 --- a/areal/experimental/training_service/controller/controller.py +++ b/areal/experimental/training_service/controller/controller.py @@ -8,6 +8,7 @@ import threading import time import traceback +from threading import Lock from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -30,11 +31,6 @@ class GatewayTrainController: _GUARD_SUFFIX = "-guard" - # TODO(agent): Controller v2 is not yet a drop-in replacement for - # TrainController on PPO/GRPO paths. Add parity for connect_engine, - # prepare_batch/rollout_batch, and update_weights (plus the matching - # gateway/data-proxy/worker endpoints), or keep RL controllers on v1. - def __init__( self, train_engine: type[TrainEngine] | str, @@ -52,11 +48,20 @@ def __init__( self._router_addr: str = "" self._model_addr: str = "" self._worker_addrs: list[str] = [] + self._guard_addrs: list[str] = [] self._forked_services: list[tuple[str, str, int]] = [] self._service_roles: list[str] = [] self._role: str = "" self._parallel_strategy = self.train_alloc.parallel self._own_process_group = False + self.rollout: Any | None = None + self._weight_update_ctrl: Any | None = None + + # Version management + self._version_lock = Lock() + self._version = 0 + + # Shared HTTP client (lazy, per-event-loop) self._async_client: Any | None = None self._async_client_loop: asyncio.AbstractEventLoop | None = None @@ -205,6 +210,15 @@ async def _async_initialize( guard_addr_0 = f"http://{format_hostport(guard_workers[0].ip, int(guard_workers[0].worker_ports[0]))}" master_addr = guard_workers[0].ip + # Persist guard addresses so connect_engine() can allocate + # ports later (e.g. for the weight-update NCCL group). 
+ def _guard_addr(worker: Worker) -> str: + return ( + f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}" + ) + + self._guard_addrs = [_guard_addr(w) for w in guard_workers] + client = await self._get_async_client() resp = await client.post( f"{guard_addr_0}/alloc_ports", json={"count": 1}, timeout=30.0 @@ -215,10 +229,6 @@ async def _async_initialize( # ============================================================== # Step 1.5: Set NCCL env on each guard so forked workers inherit it # ============================================================== - def _guard_addr(worker: Worker) -> str: - return ( - f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}" - ) await self._async_set_guards_env( guard_workers, @@ -767,6 +777,9 @@ def eval(self) -> GatewayTrainController: def set_version(self, version: int) -> None: from areal.infra.rpc.serialization import serialize_value + with self._version_lock: + self._version = version + self._gateway_post( "/set_version", { @@ -776,7 +789,8 @@ def set_version(self, version: int) -> None: ) def get_version(self) -> int: - return int(self._gateway_get_result("/get_version")) + with self._version_lock: + return self._version def save(self, meta: Any) -> None: from areal.infra.rpc.serialization import serialize_value @@ -832,13 +846,22 @@ def get_device_stats(self) -> Any: return self._gateway_post_result("/get_device_stats", payload) def config_perf_tracer(self, config: Any, role: str) -> None: - from areal.infra.rpc.serialization import serialize_value + self._ensure_initialized() - payload = { - "args": serialize_value([]), - "kwargs": serialize_value({"config": config, "role": role}), - } - self._gateway_post("/config_perf_tracer", payload) + async def _call() -> None: + tasks = [ + self._call_worker_engine_endpoint( + addr, + "/config_perf_tracer", + args=[], + kwargs={"config": config, "rank": rank, "role": role}, + timeout=self.config.request_timeout, + ) + for rank, addr in 
enumerate(self._worker_addrs) + ] + await asyncio.gather(*tasks) + + run_async_task(_call) def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None: from areal.infra.rpc.serialization import serialize_value @@ -850,10 +873,31 @@ def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None self._gateway_post("/save_perf_tracer", payload) def clear_batches(self, *targets: Any) -> None: + from areal.infra.rpc.rtensor import RTensor, flatten_shard_ids from areal.infra.rpc.serialization import serialize_value + # Step 1: HTTP DELETE to storage nodes to evict _storage entries + # (mirrors TrainController._async_clear_batches) + shards_by_node = RTensor.collect_shards(targets) + if shards_by_node: + + async def _clear_storage(): + await asyncio.gather( + *[ + RTensor.clear_node(addr, sids) + for addr, sids in shards_by_node.items() + ], + return_exceptions=True, + ) + + run_async_task(_clear_storage) + + # Step 2: Drain _fetch_buffer on workers via engine.clear_batches(shard_ids) + shard_ids = flatten_shard_ids(targets) + if not shard_ids: + return payload = { - "args": serialize_value(list(targets)), + "args": serialize_value([shard_ids]), "kwargs": serialize_value({}), } self._gateway_post("/clear_batches", payload) @@ -883,6 +927,135 @@ def data_parallel_rank(self) -> int: def cpu_group(self): return None + @property + def train_worker_urls(self) -> list[str]: + return list(self._worker_addrs) + + # -- RL parity methods (connect_engine / update_weights / batch) -------- + + def connect_engine(self, rollout: Any, meta: Any) -> None: + self._ensure_initialized() + import requests + + from areal.experimental.inference_service.controller.controller import ( + RolloutControllerV2, + ) + from areal.experimental.weight_update.controller.config import ( + WeightUpdateControllerConfig, + ) + from areal.experimental.weight_update.controller.controller import ( + WeightUpdateController, + ) + + if not isinstance(rollout, 
RolloutControllerV2): + raise TypeError( + f"GatewayTrainController requires RolloutControllerV2, " + f"got {type(rollout).__name__}. " + f"Ensure _version='v2' is set on InferenceEngineConfig." + ) + + self.rollout = rollout + + if meta.type != "awex": + raise ValueError( + f"GatewayTrainController only supports 'awex' weight updates, got '{meta.type}'" + ) + + ctrl = WeightUpdateController( + WeightUpdateControllerConfig( + admin_api_key=self.config.admin_api_key, + log_level=self.config.log_level, + ) + ) + ctrl.initialize() + + inference_urls: list[str] = rollout.inference_worker_urls + + nccl_master_addr = "" + nccl_master_port = 0 + if self._guard_addrs: + resp = requests.post( + f"{self._guard_addrs[0]}/alloc_ports", + json={"count": 1}, + timeout=30, + ) + resp.raise_for_status() + port_data = resp.json() + nccl_master_addr = port_data["host"] + nccl_master_port = port_data["ports"][0] + + pair_name = f"{self._role}-rollout" + ctrl.connect( + pair_name=pair_name, + train_worker_urls=self._worker_addrs, + inference_worker_urls=inference_urls, + nccl_master_addr=nccl_master_addr, + nccl_master_port=nccl_master_port, + ) + self._weight_update_ctrl = ctrl + logger.info( + "WeightUpdateController connected (pair=%s, train=%d, inf=%d)", + pair_name, + len(self._worker_addrs), + len(inference_urls), + ) + + def update_weights(self, meta: Any) -> None: + if self._weight_update_ctrl is None or self.rollout is None: + raise RuntimeError( + "connect_engine() must be called before update_weights()" + ) + assert meta.version is not None and meta.version > 0, ( + f"meta.version must be a positive integer, got {meta.version}" + ) + self.rollout.pause_generation() + result = self._weight_update_ctrl.update_weights(version=meta.version) + self.rollout.continue_generation() + logger.info( + "Weight update v%d completed (%s, %.0fms)", + meta.version, + result.status, + result.duration_ms, + ) + + def prepare_batch( + self, + dataloader: Any, + workflow: Any, + 
workflow_kwargs: dict[str, Any], + should_accept_fn: str | None = None, + group_size: int = 1, + dynamic_bs: bool = False, + ) -> list[dict[str, Any]]: + if self.rollout is None: + raise RuntimeError("connect_engine() must be called before prepare_batch()") + return self.rollout.prepare_batch( + dataloader=dataloader, + workflow=workflow, + workflow_kwargs=workflow_kwargs, + should_accept_fn=should_accept_fn, + group_size=group_size, + dynamic_bs=dynamic_bs, + ) + + def rollout_batch( + self, + data: list[dict[str, Any]], + workflow: Any, + workflow_kwargs: dict[str, Any], + should_accept_fn: str | None = None, + group_size: int = 1, + ) -> list[dict[str, Any]]: + if self.rollout is None: + raise RuntimeError("connect_engine() must be called before rollout_batch()") + return self.rollout.rollout_batch( + data=data, + workflow=workflow, + workflow_kwargs=workflow_kwargs, + should_accept_fn=should_accept_fn, + group_size=group_size, + ) + def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): self._parallel_strategy = parallel_strategy import torch.distributed as dist diff --git a/areal/experimental/weight_update/awex/fsdp_adapter.py b/areal/experimental/weight_update/awex/fsdp_adapter.py index aff173e06a..9c061bacb8 100644 --- a/areal/experimental/weight_update/awex/fsdp_adapter.py +++ b/areal/experimental/weight_update/awex/fsdp_adapter.py @@ -59,12 +59,18 @@ def parallelism_strategy(self) -> dict: "dp_replicated": False, } + @property + def _tie_word_embeddings(self) -> bool: + return getattr(self._engine.model_config, "tie_word_embeddings", False) + def get_weight_metadata(self) -> list[ParameterMeta]: rank_info = self._build_rank_info() metadata: list[ParameterMeta] = [] for raw_name, param in self._engine.model.named_parameters(): name = self._to_hf_name(raw_name) + if self._tie_word_embeddings and name == "lm_head.weight": + continue tensor = param.data if isinstance(tensor, DTensor): shard_meta = 
self._extract_dtensor_shard_meta(name, tensor, rank_info) @@ -99,6 +105,8 @@ def get_local_shard_parameters( for raw_name, param in self._engine.model.named_parameters(): name = self._to_hf_name(raw_name) + if self._tie_word_embeddings and name == "lm_head.weight": + continue if required is not None and name not in required: continue diff --git a/areal/experimental/weight_update/controller/controller.py b/areal/experimental/weight_update/controller/controller.py index 01dad28a72..5254f2d594 100644 --- a/areal/experimental/weight_update/controller/controller.py +++ b/areal/experimental/weight_update/controller/controller.py @@ -112,19 +112,24 @@ def connect( save_path: str = "", use_lora: bool = False, lora_name: str = "", + nccl_master_addr: str = "", + nccl_master_port: int = 0, ) -> None: self._pair_name = pair_name + payload: dict[str, Any] = { + "pair_name": pair_name, + "train_worker_urls": train_worker_urls, + "inference_worker_urls": inference_worker_urls, + "mode": mode, + "save_path": save_path, + "use_lora": use_lora, + "lora_name": lora_name, + "nccl_master_addr": nccl_master_addr, + "nccl_master_port": nccl_master_port, + } resp = self._http.post( f"{self._gateway_url}/connect", - json={ - "pair_name": pair_name, - "train_worker_urls": train_worker_urls, - "inference_worker_urls": inference_worker_urls, - "mode": mode, - "save_path": save_path, - "use_lora": use_lora, - "lora_name": lora_name, - }, + json=payload, timeout=self.config.request_timeout, ) resp.raise_for_status() diff --git a/areal/experimental/weight_update/gateway/app.py b/areal/experimental/weight_update/gateway/app.py index c0d8de2a68..f938bc2c91 100644 --- a/areal/experimental/weight_update/gateway/app.py +++ b/areal/experimental/weight_update/gateway/app.py @@ -363,41 +363,6 @@ async def connect(request: Request, body: ConnectRequest) -> ConnectResponse: logger.info("Connected pair '%s'", pair_name) return ConnectResponse(pair_name=pair_name) - @asynccontextmanager - async def 
_inference_paused( - session: aiohttp.ClientSession, - inference_urls: list[str], - timeout_s: float, - pair_name: str, - ): - await asyncio.gather( - *[ - _post(session, f"{url}/pause_generation", timeout_s, json_data={}) - for url in inference_urls - ] - ) - try: - yield - finally: - try: - await asyncio.gather( - *[ - _post( - session, - f"{url}/continue_generation", - timeout_s, - json_data={}, - ) - for url in inference_urls - ] - ) - except Exception: - logger.warning( - "Failed to resume inference for pair '%s'", - pair_name, - exc_info=True, - ) - async def _awex_transfer_weights( pair_info: PairInfo, version: int, @@ -489,20 +454,14 @@ async def update_weights( start = time.monotonic() try: - async with _inference_paused( - session, - pair_info.inference_worker_urls, - timeout_s, - pair_info.pair_name, - ): - if pair_info.mode == "disk": - await _disk_transfer_weights( - pair_info, body.version, session, timeout_s - ) - else: - await _awex_transfer_weights( - pair_info, body.version, session, timeout_s - ) + if pair_info.mode == "disk": + await _disk_transfer_weights( + pair_info, body.version, session, timeout_s + ) + else: + await _awex_transfer_weights( + pair_info, body.version, session, timeout_s + ) except Exception as e: duration_ms = (time.monotonic() - start) * 1000 logger.error( diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index b484831ddd..ec327cc545 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -474,6 +474,7 @@ def serve(): host="0.0.0.0", port=self._proxy_gateway_port, log_level="warning", + access_log=False, ) server = uvicorn.Server(config) self._proxy_gateway_server = server diff --git a/areal/infra/workflow_context.py b/areal/infra/workflow_context.py index 20bb07b265..beb9513807 100644 --- a/areal/infra/workflow_context.py +++ b/areal/infra/workflow_context.py @@ -11,7 +11,11 @@ import httpx from 
areal.infra.utils.concurrent import register_loop_cleanup -from areal.infra.utils.http import DEFAULT_REQUEST_TIMEOUT, get_default_connector +from areal.infra.utils.http import ( + DEFAULT_REQUEST_TIMEOUT, + create_httpx_client, + get_default_connector, +) from areal.utils import logging logger = logging.getLogger("WorkflowContext") @@ -126,7 +130,11 @@ async def get_aiohttp_session(self) -> aiohttp.ClientSession: await self._check_event_loop_change() if self._aiohttp_session is None: - timeout = aiohttp.ClientTimeout(total=DEFAULT_REQUEST_TIMEOUT) + timeout = aiohttp.ClientTimeout( + total=DEFAULT_REQUEST_TIMEOUT, + sock_connect=DEFAULT_REQUEST_TIMEOUT, + connect=DEFAULT_REQUEST_TIMEOUT, + ) self._aiohttp_session = aiohttp.ClientSession( timeout=timeout, read_bufsize=1024 * 1024 * 10, @@ -165,9 +173,7 @@ async def get_httpx_client(self) -> httpx.AsyncClient: await self._check_event_loop_change() if self._httpx_client is None: - self._httpx_client = httpx.AsyncClient( - timeout=httpx.Timeout(DEFAULT_REQUEST_TIMEOUT) - ) + self._httpx_client = create_httpx_client(timeout=DEFAULT_REQUEST_TIMEOUT) # Track which event loop this client belongs to self._event_loop = asyncio.get_running_loop() diff --git a/areal/reward/__init__.py b/areal/reward/__init__.py index b38e9fd147..d6b8e42359 100644 --- a/areal/reward/__init__.py +++ b/areal/reward/__init__.py @@ -9,24 +9,6 @@ logger = logging.getLogger("RewardUtils") -VALID_REWARD_FN = ["clevr_count_70k", "geometry3k"] - - -def get_custom_reward_fn(path: str, **kwargs): - if "clevr_count_70k" in path: - from .clevr_count_70k import clevr_count_70k_reward_fn - - return clevr_count_70k_reward_fn - elif "geometry3k" in path: - from .geometry3k import geometry3k_reward_fn - - return geometry3k_reward_fn - else: - raise ValueError( - f"Reward function {path} is not supported. " - f"Supported reward functions are: {VALID_REWARD_FN}. 
" - ) - class MathVerifyWorker: """Thin wrapper over math_verify with configurable extraction/precision. @@ -120,8 +102,6 @@ def get_math_verify_worker() -> MathVerifyWorker: __all__ = [ - "VALID_REWARD_FN", - "get_custom_reward_fn", "MathVerifyWorker", "get_math_verify_worker", "gsm8k_reward_fn", diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 083e352e27..2eb9c10cb6 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -156,7 +156,7 @@ def __init__( if self._online_mode and config.valid_dataset is not None: raise ValueError( "valid_dataset must not be set when using online RL mode " - "(openai.mode='online'). Online mode does not support " + "(agent.mode='online'). Online mode does not support " "validation datasets." ) @@ -164,7 +164,7 @@ def __init__( if not self._online_mode and train_dataset is None: raise ValueError( "train_dataset must be provided unless using online RL mode " - "(openai.mode='online')." + "(agent.mode='online')." ) # Create models: actor, critic, ref — each with its own allocation. @@ -302,7 +302,18 @@ def __init__( self._proxy_started = False # Prepare weight update meta and connect to inference engine - if self.config.actor.weight_update_mode == "disk": + if self.config.actor._version == "v2": + awex_kwargs: dict[str, Any] = {} + if config.actor.use_lora: + awex_kwargs.update( + { + "use_lora": config.actor.use_lora, + "lora_name": config.gconfig.lora_name, + "base_model_name": config.actor.path, + } + ) + self.weight_update_meta = WeightUpdateMeta.from_awex(**awex_kwargs) + elif self.config.actor.weight_update_mode == "disk": disk_kwargs = { "experiment_name": config.experiment_name, "trial_name": config.trial_name, @@ -1225,6 +1236,15 @@ def _validate_cfg(self): "switch actor backend from Megatron." ) + # Ensure actor and rollout controller versions match. 
+ actor_version = self.config.actor._version + rollout_version = self.config.rollout._version + if actor_version != rollout_version: + raise ValueError( + f"actor._version ('{actor_version}') and rollout._version " + f"('{rollout_version}') must match. Both must be 'v1' or both 'v2'." + ) + def _requires_proxy_workflow(self, workflow: WorkflowLike | None) -> bool: """Check if workflow requires proxy workers (i.e., not a RolloutWorkflow). @@ -1288,8 +1308,11 @@ def _ensure_proxy_started(self) -> None: if self.config.scheduler.type == "ray": raise NotImplementedError("Proxy workers not supported with RayScheduler") - assert isinstance(self.rollout, RolloutController) + if not isinstance(self.rollout, RolloutController): + self._proxy_started = True + return + # v1 controller needs an explicit proxy launch call logger.info("Initializing proxy workers for AgentWorkflow support") self.rollout.start_proxy() if self.eval_rollout is not None: diff --git a/areal/workflow/openai/math_agent.py b/areal/workflow/openai/math_agent.py index 886b89c0d3..f1d7531bcf 100644 --- a/areal/workflow/openai/math_agent.py +++ b/areal/workflow/openai/math_agent.py @@ -29,6 +29,7 @@ def __init__(self, **kwargs): self.kwargs = kwargs.copy() self.kwargs.pop("max_tokens", None) self.kwargs.pop("max_turns", None) + self._reward_fn = AsyncRewardWrapper(math_reward_fn) async def run(self, data: dict, **extra_kwargs): http_client = extra_kwargs.get("http_client", None) @@ -41,8 +42,7 @@ async def run(self, data: dict, **extra_kwargs): messages=data["messages"], model="default", **self.kwargs ) - reward_fn = AsyncRewardWrapper(math_reward_fn) - return await reward_fn( + return await self._reward_fn( completions=comp.choices[0].message.content, answer=data["answer"] ) @@ -52,6 +52,7 @@ def __init__(self, max_turns: int = 8, **kwargs): self.max_turns = max_turns self.kwargs = kwargs.copy() self.kwargs.pop("max_tokens", None) + self._reward_fn = AsyncRewardWrapper(math_reward_fn) async def run(self, 
data: dict, **extra_kwargs): http_client = extra_kwargs.get("http_client", None) @@ -70,8 +71,9 @@ async def run(self, data: dict, **extra_kwargs): ) message = response.choices[0].message messages.append(message.model_dump(exclude_none=True)) - reward_fn = AsyncRewardWrapper(math_reward_fn) - reward = await reward_fn(completions=message.content, answer=data["answer"]) + reward = await self._reward_fn( + completions=message.content, answer=data["answer"] + ) rewards[response.id] = reward if reward == 1: break @@ -131,6 +133,7 @@ def __init__(self, **kwargs): self.kwargs = kwargs.copy() self.kwargs.pop("max_tokens", None) self.kwargs.pop("max_turns", None) + self._reward_fn = AsyncRewardWrapper(math_reward_fn) async def run(self, data: dict, **extra_kwargs): http_client = extra_kwargs.get("http_client", None) @@ -142,7 +145,7 @@ async def run(self, data: dict, **extra_kwargs): content = data["messages"][-1]["content"] run_config = RunConfig( model_provider=OpenAIProvider(openai_client=client), - model="default", # no need to pass + model="default", tracing_disabled=True, model_settings=ModelSettings(**self.kwargs), ) @@ -163,6 +166,6 @@ async def run(self, data: dict, **extra_kwargs): agent, input=content, session=session, run_config=run_config ) - reward_fn = AsyncRewardWrapper(math_reward_fn) - reward = await reward_fn(completions=result.final_output, answer=data["answer"]) - return reward + return await self._reward_fn( + completions=result.final_output, answer=data["answer"] + ) diff --git a/examples/math/boba_grpo.yaml b/examples/math/boba_grpo.yaml index 694512fd43..e26011248c 100644 --- a/examples/math/boba_grpo.yaml +++ b/examples/math/boba_grpo.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 16 min_new_tokens: 0 max_new_tokens: 8192 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git 
a/examples/math/gsm8k_dapo_dynamic_bs.yaml b/examples/math/gsm8k_dapo_dynamic_bs.yaml index ef5074f118..596047b383 100644 --- a/examples/math/gsm8k_dapo_dynamic_bs.yaml +++ b/examples/math/gsm8k_dapo_dynamic_bs.yaml @@ -32,11 +32,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 4096 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_drgrpo.yaml b/examples/math/gsm8k_drgrpo.yaml index 2991358056..3992769f19 100644 --- a/examples/math/gsm8k_drgrpo.yaml +++ b/examples/math/gsm8k_drgrpo.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_grpo.yaml b/examples/math/gsm8k_grpo.yaml index 18264f920a..3bdd62664c 100644 --- a/examples/math/gsm8k_grpo.yaml +++ b/examples/math/gsm8k_grpo.yaml @@ -30,12 +30,17 @@ rollout: scheduling_spec: ${actor.scheduling_spec} fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} - dump_to_file: true + dump_to_file: false + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_grpo_cpu.yaml b/examples/math/gsm8k_grpo_cpu.yaml index add22b8ae6..573c8ef80c 100644 --- a/examples/math/gsm8k_grpo_cpu.yaml +++ b/examples/math/gsm8k_grpo_cpu.yaml @@ -34,11 +34,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 2 min_new_tokens: 0 max_new_tokens: 256 + max_tokens: 
2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_grpo_lora.yaml b/examples/math/gsm8k_grpo_lora.yaml index 4a4473efd4..d98d50a621 100644 --- a/examples/math/gsm8k_grpo_lora.yaml +++ b/examples/math/gsm8k_grpo_lora.yaml @@ -32,11 +32,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 lora_name: "lora-gsm8k" diff --git a/examples/math/gsm8k_grpo_megatron.yaml b/examples/math/gsm8k_grpo_megatron.yaml index 2482b297bc..8b8dfe66a2 100644 --- a/examples/math/gsm8k_grpo_megatron.yaml +++ b/examples/math/gsm8k_grpo_megatron.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_grpo_megatron_fp8.yaml b/examples/math/gsm8k_grpo_megatron_fp8.yaml index 376a54ab1d..9c3a960a5c 100644 --- a/examples/math/gsm8k_grpo_megatron_fp8.yaml +++ b/examples/math/gsm8k_grpo_megatron_fp8.yaml @@ -27,11 +27,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_grpo_megatron_lora.yaml b/examples/math/gsm8k_grpo_megatron_lora.yaml index 3dfff966ac..99628ae2e3 100644 --- a/examples/math/gsm8k_grpo_megatron_lora.yaml +++ b/examples/math/gsm8k_grpo_megatron_lora.yaml @@ -31,12 +31,17 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: 
individual + turn_discount: 1.0 use_lora: true gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 lora_name: "lora-gsm8k" diff --git a/examples/math/gsm8k_grpo_megatron_lora_moe.yaml b/examples/math/gsm8k_grpo_megatron_lora_moe.yaml index 8ce555d901..d431504db6 100644 --- a/examples/math/gsm8k_grpo_megatron_lora_moe.yaml +++ b/examples/math/gsm8k_grpo_megatron_lora_moe.yaml @@ -31,12 +31,17 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 use_lora: true gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 lora_name: "lora-gsm8k" diff --git a/examples/math/gsm8k_grpo_npu.yaml b/examples/math/gsm8k_grpo_npu.yaml index 112e5fce05..d4f66aed05 100644 --- a/examples/math/gsm8k_grpo_npu.yaml +++ b/examples/math/gsm8k_grpo_npu.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_gspo.yaml b/examples/math/gsm8k_gspo.yaml index 6caf80a5e3..a76a0046bc 100644 --- a/examples/math/gsm8k_gspo.yaml +++ b/examples/math/gsm8k_gspo.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_liteppo.yaml b/examples/math/gsm8k_liteppo.yaml index 40499d232c..914784c70f 100644 --- a/examples/math/gsm8k_liteppo.yaml +++ b/examples/math/gsm8k_liteppo.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} 
tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_m2po.yaml b/examples/math/gsm8k_m2po.yaml index ae8fd03641..5788a71711 100644 --- a/examples/math/gsm8k_m2po.yaml +++ b/examples/math/gsm8k_m2po.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_ppo.yaml b/examples/math/gsm8k_ppo.yaml index f3544db2c8..ce8e29fcd5 100644 --- a/examples/math/gsm8k_ppo.yaml +++ b/examples/math/gsm8k_ppo.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_ppo_megatron.yaml b/examples/math/gsm8k_ppo_megatron.yaml index e75a341906..7e6dab30c7 100644 --- a/examples/math/gsm8k_ppo_megatron.yaml +++ b/examples/math/gsm8k_ppo_megatron.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 1 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_reinforce.yaml b/examples/math/gsm8k_reinforce.yaml index 6944cd3bdf..cd9731c718 100644 --- a/examples/math/gsm8k_reinforce.yaml +++ b/examples/math/gsm8k_reinforce.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} 
dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_reinforce_baseline.yaml b/examples/math/gsm8k_reinforce_baseline.yaml index cfc144b92a..8d6ae20771 100644 --- a/examples/math/gsm8k_reinforce_baseline.yaml +++ b/examples/math/gsm8k_reinforce_baseline.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_rl.py b/examples/math/gsm8k_rl.py index 4dbed5fa60..fef97b0a84 100644 --- a/examples/math/gsm8k_rl.py +++ b/examples/math/gsm8k_rl.py @@ -22,13 +22,13 @@ def main(args): ) workflow_kwargs = dict( - reward_fn="areal.reward.gsm8k.gsm8k_reward_fn", - gconfig=config.gconfig, - tokenizer=config.tokenizer_path, - enable_thinking=False, + temperature=config.gconfig.temperature, + top_p=config.gconfig.top_p, + max_tokens=config.gconfig.max_tokens, + max_completion_tokens=config.gconfig.max_new_tokens, ) eval_workflow_kwargs = workflow_kwargs.copy() - eval_workflow_kwargs["gconfig"] = config.gconfig.new(temperature=0.6) + eval_workflow_kwargs["temperature"] = 0.6 with PPOTrainer( config, @@ -36,9 +36,9 @@ def main(args): valid_dataset=valid_dataset, ) as trainer: trainer.train( - workflow="areal.workflow.rlvr.RLVRWorkflow", + workflow="areal.workflow.openai.math_agent.MathAgent", workflow_kwargs=workflow_kwargs, - eval_workflow="areal.workflow.rlvr.RLVRWorkflow", + eval_workflow="areal.workflow.openai.math_agent.MathAgent", eval_workflow_kwargs=eval_workflow_kwargs, ) diff --git a/examples/math/gsm8k_rloo.yaml b/examples/math/gsm8k_rloo.yaml index 867a6552f7..a1ad57494e 100644 --- a/examples/math/gsm8k_rloo.yaml +++ 
b/examples/math/gsm8k_rloo.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/examples/math/gsm8k_sapo.yaml b/examples/math/gsm8k_sapo.yaml index da7825e150..96d68abcd5 100644 --- a/examples/math/gsm8k_sapo.yaml +++ b/examples/math/gsm8k_sapo.yaml @@ -31,11 +31,16 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + agent: + mode: inline + export_style: individual + turn_discount: 1.0 gconfig: n_samples: 4 min_new_tokens: 0 max_new_tokens: 1024 + max_tokens: 2048 greedy: false temperature: 1.0 diff --git a/pyproject.toml b/pyproject.toml index e97e0ca6f6..6a215d42e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,7 @@ dependencies = [ "python_dateutil", "word2number", "pebble", + "tenacity>=8.2.0", "timeout-decorator", "prettytable", "h5py", diff --git a/pyproject.vllm.toml b/pyproject.vllm.toml index 1e61f0eef0..6e5bfe9cd6 100644 --- a/pyproject.vllm.toml +++ b/pyproject.vllm.toml @@ -121,6 +121,7 @@ dependencies = [ "python_dateutil", "word2number", "pebble", + "tenacity>=8.2.0", "timeout-decorator", "prettytable", "h5py", diff --git a/tests/experimental/inference_service/test_controller.py b/tests/experimental/inference_service/test_controller.py index 8e50d003c1..a6d5b92f63 100644 --- a/tests/experimental/inference_service/test_controller.py +++ b/tests/experimental/inference_service/test_controller.py @@ -412,
+412,6 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy assert "http://127.0.0.1:19000" in data_proxy_cmd -# ============================================================================= -# RolloutControllerV2 — gateway HTTP helpers -# ============================================================================= - - -class TestRolloutControllerV2HTTP: - def test_gateway_http_post_raises_on_failure(self): - cfg = InferenceEngineConfig(backend="sglang:d1", admin_api_key="test-key") - scheduler = MagicMock(n_gpus_per_node=8) - controller = RolloutControllerV2(config=cfg, scheduler=scheduler) - controller._gateway_addr = "http://127.0.0.1:19999" - with pytest.raises(RuntimeError, match="Failed to POST"): - controller._gateway_http_post("/test", {"key": "value"}) - - def test_gateway_http_post_sends_auth(self): - mock_resp = MagicMock() - mock_resp.status_code = 200 - - cfg = InferenceEngineConfig( - backend="sglang:d1", - admin_api_key="my-secret-key", - ) - scheduler = MagicMock(n_gpus_per_node=8) - controller = RolloutControllerV2(config=cfg, scheduler=scheduler) - controller._gateway_addr = "http://127.0.0.1:8080" - - with patch.object( - controller._sync_client, "post", return_value=mock_resp - ) as mock_post: - controller._gateway_http_post("/test_endpoint", {"data": 1}) - - mock_post.assert_called_once() - call_kwargs = mock_post.call_args - assert "Bearer my-secret-key" in str(call_kwargs) - assert "http://127.0.0.1:8080/test_endpoint" in str(call_kwargs) - - class TestOnlineCallbackFlow: @pytest.mark.asyncio async def test_online_callback_without_waiter_buffers_export_request(self): diff --git a/tests/experimental/inference_service/test_controller_version.py b/tests/experimental/inference_service/test_controller_version.py index 0e3d3b4b86..bba3fa7841 100644 --- a/tests/experimental/inference_service/test_controller_version.py +++ b/tests/experimental/inference_service/test_controller_version.py @@ -63,23 +63,25 @@ def 
test_set_version_updates_local(self): def test_set_version_no_gateway_skips_broadcast(self): """When _gateway_addr is empty, set_version updates local but makes no HTTP calls.""" ctrl = _make_controller(gateway_addr="", worker_ids={"dp0": "w1"}) + ctrl._data_proxy_addrs = ["http://dp0:8000"] with patch.object( - ctrl, "_async_gateway_http_post", new_callable=AsyncMock + ctrl, "_async_data_proxy_post", new_callable=AsyncMock ) as mock_post: ctrl.set_version(5) mock_post.assert_not_called() assert ctrl._version == 5 def test_set_version_broadcasts_to_all_workers(self): - """When gateway_addr is set and 2 workers exist, broadcasts to both.""" + """When gateway_addr is set and data proxies exist, broadcasts to all.""" ctrl = _make_controller( gateway_addr="http://gateway:8000", worker_ids={"dp0": "w1", "dp1": "w2"}, ) + ctrl._data_proxy_addrs = ["http://dp0:8000", "http://dp1:8000"] mock_post = AsyncMock() - with patch.object(ctrl, "_async_gateway_http_post", mock_post): + with patch.object(ctrl, "_async_data_proxy_post", mock_post): loop = asyncio.new_event_loop() try: loop.run_until_complete(ctrl._async_set_version(10)) @@ -87,9 +89,12 @@ def test_set_version_broadcasts_to_all_workers(self): loop.close() assert mock_post.call_count == 2 - call_endpoints = [call.args[0] for call in mock_post.call_args_list] - assert "/set_version/w1" in call_endpoints - assert "/set_version/w2" in call_endpoints + call_addrs = [call.args[0] for call in mock_post.call_args_list] + assert "http://dp0:8000" in call_addrs + assert "http://dp1:8000" in call_addrs + for call in mock_post.call_args_list: + assert call.args[1] == "/set_version" + assert call.args[2] == {"version": 10} # ============================================================================= diff --git a/tests/test_async_reward_wrapper.py b/tests/test_async_reward_wrapper.py new file mode 100644 index 0000000000..d24dddcd18 --- /dev/null +++ b/tests/test_async_reward_wrapper.py @@ -0,0 +1,165 @@ +# 
SPDX-License-Identifier: Apache-2.0 + +import os +import time +from concurrent.futures import ProcessPoolExecutor +from concurrent.futures.process import BrokenProcessPool + +import pytest + +from areal.api.reward_api import AsyncRewardWrapper + + +# Module-level: ProcessPoolExecutor requires picklable callables +def _add_reward(a: float, b: float) -> float: + return a + b + + +def _slow_reward(seconds: float) -> float: + time.sleep(seconds) + return 1.0 + + +def _crash_worker() -> float: + os._exit(1) + + +def _raise_value_error() -> float: + raise ValueError("intentional error from reward fn") + + +@pytest.fixture(autouse=True) +def _isolate_executor_state(): + """Shut down and clear all shared executors before and after each test.""" + AsyncRewardWrapper._atexit_shutdown_all() + yield + AsyncRewardWrapper._atexit_shutdown_all() + + +class TestAsyncRewardWrapperBasic: + @pytest.mark.asyncio + async def test_returns_correct_result(self): + wrapper = AsyncRewardWrapper(_add_reward, max_workers=1, max_retries=0) + result = await wrapper(3.0, 4.0) + assert result == 7.0 + + @pytest.mark.asyncio + async def test_multiple_calls_return_correct_results(self): + wrapper = AsyncRewardWrapper(_add_reward, max_workers=2, max_retries=0) + results = [await wrapper(float(i), 1.0) for i in range(5)] + assert results == [1.0, 2.0, 3.0, 4.0, 5.0] + + def test_shared_executor_for_same_max_workers(self): + w1 = AsyncRewardWrapper(_add_reward, max_workers=2) + w2 = AsyncRewardWrapper(_add_reward, max_workers=2) + assert w1._executor_key == w2._executor_key + assert AsyncRewardWrapper._executors.get(2) is not None + assert len(AsyncRewardWrapper._executors) == 1 + + def test_different_executor_for_different_max_workers(self): + AsyncRewardWrapper(_add_reward, max_workers=1) + AsyncRewardWrapper(_add_reward, max_workers=2) + assert len(AsyncRewardWrapper._executors) == 2 + + +class TestAsyncRewardWrapperTimeout: + @pytest.mark.asyncio + async def 
test_timeout_returns_zero_after_retries(self): + wrapper = AsyncRewardWrapper( + _slow_reward, timeout_seconds=0.1, max_workers=1, max_retries=1 + ) + result = await wrapper(10.0) + assert result == 0 + + @pytest.mark.asyncio + async def test_timeout_no_retries_returns_zero(self): + wrapper = AsyncRewardWrapper( + _slow_reward, timeout_seconds=0.1, max_workers=1, max_retries=0 + ) + result = await wrapper(10.0) + assert result == 0 + + +class TestAsyncRewardWrapperBrokenPool: + @pytest.mark.asyncio + async def test_crash_raises_broken_process_pool(self): + wrapper = AsyncRewardWrapper( + _crash_worker, max_workers=1, max_retries=0, timeout_seconds=5 + ) + with pytest.raises(BrokenProcessPool): + await wrapper() + + @pytest.mark.asyncio + async def test_recreation_replaces_executor(self): + wrapper = AsyncRewardWrapper( + _crash_worker, max_workers=1, max_retries=1, timeout_seconds=5 + ) + executor_before = AsyncRewardWrapper._executors.get(1) + assert executor_before is not None + + with pytest.raises(BrokenProcessPool): + await wrapper() + + executor_after = AsyncRewardWrapper._executors.get(1) + assert executor_after is not executor_before + + +class TestAsyncRewardWrapperExceptionHandling: + @pytest.mark.asyncio + async def test_reward_fn_exception_propagates(self): + wrapper = AsyncRewardWrapper(_raise_value_error, max_workers=1, max_retries=0) + with pytest.raises(ValueError, match="intentional error"): + await wrapper() + + @pytest.mark.asyncio + async def test_reward_fn_exception_retries_then_raises(self): + wrapper = AsyncRewardWrapper(_raise_value_error, max_workers=1, max_retries=2) + with pytest.raises(ValueError, match="intentional error"): + await wrapper() + + @pytest.mark.asyncio + async def test_shutdown_then_call_raises_runtime_error(self): + wrapper = AsyncRewardWrapper(_add_reward, max_workers=1, max_retries=0) + AsyncRewardWrapper._atexit_shutdown_all() + with pytest.raises(RuntimeError, match="has been shut down"): + await wrapper(1.0, 2.0) + + 
+class TestAsyncRewardWrapperAtexitCleanup: + def test_atexit_clears_all_executors(self): + AsyncRewardWrapper(_add_reward, max_workers=1) + AsyncRewardWrapper(_add_reward, max_workers=2) + assert len(AsyncRewardWrapper._executors) == 2 + + AsyncRewardWrapper._atexit_shutdown_all() + assert len(AsyncRewardWrapper._executors) == 0 + + def test_atexit_is_idempotent(self): + AsyncRewardWrapper(_add_reward, max_workers=1) + AsyncRewardWrapper._atexit_shutdown_all() + AsyncRewardWrapper._atexit_shutdown_all() + assert len(AsyncRewardWrapper._executors) == 0 + + +class TestRecreateExecutorRaceSafety: + def test_recreate_skips_when_already_replaced(self): + AsyncRewardWrapper(_add_reward, max_workers=1) + original = AsyncRewardWrapper._executors[1] + + # Simulate a concurrent thread having already replaced the executor + replacement = ProcessPoolExecutor(max_workers=1) + AsyncRewardWrapper._executors[1] = replacement + + result = AsyncRewardWrapper._recreate_executor(1, 1, original) + assert result is replacement + + replacement.shutdown(wait=False) + + def test_recreate_replaces_when_identity_matches(self): + AsyncRewardWrapper(_add_reward, max_workers=1) + original = AsyncRewardWrapper._executors[1] + + result = AsyncRewardWrapper._recreate_executor(1, 1, original) + assert result is not None + assert result is not original + assert AsyncRewardWrapper._executors[1] is result diff --git a/tests/test_examples.py b/tests/test_examples.py index c2453de3af..27b952cc33 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -163,7 +163,8 @@ def test_countdown_example(tmp_path_factory): @pytest.mark.sglang @pytest.mark.multi_gpu @pytest.mark.ci -def test_gsm8k_grpo(tmp_path_factory): +@pytest.mark.parametrize("_version", ["v1", "v2"]) +def test_gsm8k_grpo(tmp_path_factory, _version): experiments_path = tmp_path_factory.mktemp("experiments") name_resolve_path = tmp_path_factory.mktemp("name_resolve") model_path = get_model_path( @@ -192,9 +193,11 @@ def 
test_gsm8k_grpo(tmp_path_factory): f"cluster.name_resolve.nfs_record_root={str(name_resolve_path)}", f"actor.path={model_path}", "scheduler.type=local", + f"+actor._version={_version}", + f"+rollout._version={_version}", timeout=900, ) - assert success, "GSM8K GRPO example failed" + assert success, f"GSM8K GRPO example failed (_version={_version})" @pytest.mark.parametrize( diff --git a/uv.lock b/uv.lock index 48c95e0e99..5e1f67b637 100644 --- a/uv.lock +++ b/uv.lock @@ -548,6 +548,7 @@ requires-dist = [ { name = "swanlab", extras = ["dashboard"], specifier = "==0.6.12" }, { name = "tabulate" }, { name = "tenacity" }, + { name = "tenacity", specifier = ">=8.2.0" }, { name = "tensorboardx" }, { name = "tilelang", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'cuda-train'", specifier = ">=0.1.9" }, { name = "timeout-decorator" }, diff --git a/uv.vllm.lock b/uv.vllm.lock index 85d1feeb9a..4374ac77b1 100644 --- a/uv.vllm.lock +++ b/uv.vllm.lock @@ -592,6 +592,7 @@ requires-dist = [ { name = "swanlab", extras = ["dashboard"], specifier = "==0.6.12" }, { name = "tabulate" }, { name = "tenacity" }, + { name = "tenacity", specifier = ">=8.2.0" }, { name = "tensorboardx" }, { name = "tilelang", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'cuda-train'", specifier = ">=0.1.9" }, { name = "timeout-decorator" }, From 8a205938b3bfd791334ca0e86f2b5e2d059b1843 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=9A=E6=83=9F?= Date: Mon, 11 May 2026 15:44:15 +0800 Subject: [PATCH 2/2] fix(archon): abandon rollout group when any session fails Partial groups produce inconsistent training data. Reject the entire group if any _run_one call raises, instead of silently returning the successful subset with 0.0 rewards for failures. 
--- .../inference_service/controller/workflow.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/areal/experimental/inference_service/controller/workflow.py b/areal/experimental/inference_service/controller/workflow.py index 4ab4345b26..f7d3c4ac73 100644 --- a/areal/experimental/inference_service/controller/workflow.py +++ b/areal/experimental/inference_service/controller/workflow.py @@ -147,7 +147,8 @@ async def _run_offline( assert self.agent is not None http_client = await workflow_context.get_httpx_client() - async def _run_one(session_id: str, session_api_key: str) -> float: + async def _run_one(session_id: str, session_api_key: str) -> float | None: + """Run one agent session. Returns reward on success, ``None`` on failure.""" try: rewards = await self.agent.run( data, @@ -191,13 +192,16 @@ async def _run_one(session_id: str, session_api_key: str) -> float: session_id, group_id, ) - return 0.0 + return None - rewards = await asyncio.gather( + results = await asyncio.gather( *[_run_one(sid, api_key) for sid, api_key in sessions] ) session_ids = [sid for sid, _ in sessions] + + # Always export to trigger session cleanup on the data proxy, + # even when we intend to discard the trajectories. traj = await self._export_interactions( http_session, session_ids, @@ -206,8 +210,18 @@ async def _run_one(session_id: str, session_api_key: str) -> float: if not traj: return None + n_failed = sum(r is None for r in results) + if n_failed > 0: + logger.warning( + "Abandoning group %s: %d/%d sessions failed", + group_id, + n_failed, + len(sessions), + ) + return None + tracker = stats_tracker.get(workflow_context.stat_scope()) - for r in rewards: + for r in results: tracker.scalar(reward=r) return traj