Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 69 additions & 95 deletions areal/api/reward_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import atexit
import os
import threading
import traceback
import weakref
from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
Expand Down Expand Up @@ -61,15 +60,18 @@ def reward_fn(


class AsyncRewardWrapper:
"""
Wraps a synchronous reward function to make it async with timeout handling.
Automatically manages ProcessPoolExecutor lifecycle based on instance count.
"""Wraps a synchronous reward function for async execution with timeout and retries.

Executors are shared by ``max_workers`` key and cleaned up via ``atexit``.
Includes automatic recovery from broken process pools.

The reward function and its arguments must be picklable since they
are dispatched to worker processes via ``ProcessPoolExecutor``.
"""

_executors = {}
_instance_counts = {}
_executors: dict[int, ProcessPoolExecutor] = {}
_lock = threading.Lock()
_atexit_registered = False

def __init__(
self,
Expand All @@ -83,8 +85,6 @@ def __init__(
if max_workers is None:
cpu_count = os.cpu_count() or 1
device_count = _get_device_count_safely()
# Heuristic for max_workers: distribute CPU cores across devices,
# then halve to be conservative, ensuring at least one worker.
max_workers = max((cpu_count // device_count) // 2, 1)
self.max_workers = max_workers
self.max_retries = max_retries
Expand All @@ -95,115 +95,89 @@ def __init__(
self._executors[self._executor_key] = ProcessPoolExecutor(
max_workers=max_workers
)
self._instance_counts[self._executor_key] = 0
self._instance_counts[self._executor_key] += 1

weakref.finalize(self, AsyncRewardWrapper._cleanup_executor, max_workers)
if not AsyncRewardWrapper._atexit_registered:
atexit.register(AsyncRewardWrapper._atexit_shutdown_all)
AsyncRewardWrapper._atexit_registered = True

@classmethod
def _atexit_shutdown_all(cls):
    """Shut down all shared executors at interpreter exit.

    Registered once via ``atexit`` so it runs before ``concurrent.futures``'
    own ``_python_exit`` hook, preventing worker processes from
    deadlocking in ``_finalize_join``. Must use ``wait=True`` so the
    result queue is fully drained before the pool is torn down.
    """
    with cls._lock:
        for executor in cls._executors.values():
            try:
                executor.shutdown(wait=True, cancel_futures=True)
            except Exception as e:
                # Best-effort cleanup: failing to stop one pool should
                # not prevent shutting down the remaining ones.
                logger.warning(f"Error shutting down executor at exit: {e}")
        cls._executors.clear()
@classmethod
def _recreate_executor(
    cls,
    executor_key: int,
    max_workers: int,
    broken: ProcessPoolExecutor,
) -> ProcessPoolExecutor | None:
    """Replace a broken ``ProcessPoolExecutor`` for ``executor_key``.

    Returns the executor currently registered for the key (another
    thread may already have swapped in a fresh one), the newly created
    replacement, or ``None`` if a replacement could not be created.
    """
    with cls._lock:
        current = cls._executors.get(executor_key)
        # If the registered executor is no longer the broken instance,
        # a concurrent caller already recreated it — reuse that one.
        if current is not broken:
            return current
        try:
            broken.shutdown(wait=False)
        except Exception as e:
            logger.warning(f"Error shutting down broken executor: {e}")
        try:
            new_executor = ProcessPoolExecutor(max_workers=max_workers)
        except Exception:
            logger.exception("Failed to create replacement ProcessPoolExecutor")
            # Drop the dead entry so callers see a missing executor
            # rather than a broken one.
            cls._executors.pop(executor_key, None)
            return None
        cls._executors[executor_key] = new_executor
        logger.info(f"Recreated ProcessPoolExecutor with {max_workers} workers")
        return new_executor

async def __call__(self, *args, **kwargs) -> float:
    """Run ``self.reward_fn(*args, **kwargs)`` in the shared process pool.

    Retries up to ``self.max_retries`` additional times on timeout or a
    broken pool. A timeout on the final attempt returns ``0``; a broken
    pool triggers executor recreation between attempts and is re-raised
    if recovery fails or retries are exhausted; any other exception on
    the final attempt is re-raised.

    Raises:
        RuntimeError: if the shared executor has already been shut down.
    """
    for attempt in range(self.max_retries + 1):
        executor = self._executors.get(self._executor_key)
        if executor is None:
            raise RuntimeError("ProcessPoolExecutor has been shut down")

        is_last = attempt == self.max_retries
        try:
            # The reward function and its arguments are pickled and
            # dispatched to a worker process.
            future = asyncio.get_running_loop().run_in_executor(
                executor,
                partial(self.reward_fn, *args, **kwargs),
            )
            return await asyncio.wait_for(future, timeout=self.timeout_seconds)
        except TimeoutError:
            logger.warning(
                f"Computing reward timeout after {self.timeout_seconds}s "
                f"(attempt {attempt + 1}/{self.max_retries + 1}). "
                f"{'Returning 0.' if is_last else 'Retrying...'}"
            )
            if is_last:
                return 0
        except BrokenProcessPool:
            logger.warning(
                f"ProcessPoolExecutor broken (attempt {attempt + 1}/{self.max_retries + 1}). "
                "Attempting to recreate..."
            )
            if is_last:
                raise
            # Swap in a fresh pool; if that fails there is nothing left
            # to retry against, so propagate the original error.
            if (
                self._recreate_executor(
                    self._executor_key, self.max_workers, executor
                )
                is None
            ):
                raise
        except Exception:
            logger.exception(
                f"Reward computation error (attempt {attempt + 1}/{self.max_retries + 1})"
            )
            if is_last:
                raise

    # Unreachable in practice: the final attempt always returns or
    # raises. Kept as a defensive fallback.
    return 0
10 changes: 7 additions & 3 deletions areal/engine/sglang_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,9 +530,13 @@ def export_stats(self) -> dict[str, float]:
return stats_tracker.export_all(reduce_group=None)

@classmethod
def as_controller(cls, config: InferenceEngineConfig, scheduler: Scheduler):
    """Build a rollout controller for this engine class.

    Dispatches to the experimental ``RolloutControllerV2`` when the
    config requests the ``"v2"`` implementation; otherwise wraps this
    engine class in the standard ``RolloutController``.
    """
    if config._version == "v2":
        # Imported lazily so the experimental package is only pulled in
        # when the v2 path is actually selected.
        from areal.experimental.inference_service.controller.controller import (
            RolloutControllerV2,
        )

        return RolloutControllerV2(config=config, scheduler=scheduler)
    return RolloutController(cls, config=config, scheduler=scheduler)

def clear_batches(self, shard_ids: list[str]) -> None:
Expand Down
10 changes: 7 additions & 3 deletions areal/engine/vllm_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,9 +493,13 @@ def export_stats(self) -> dict[str, float]:
return stats_tracker.export_all(reduce_group=None)

@classmethod
def as_controller(cls, config: InferenceEngineConfig, scheduler: Scheduler):
    """Build a rollout controller for this engine class.

    Dispatches to the experimental ``RolloutControllerV2`` when the
    config requests the ``"v2"`` implementation; otherwise wraps this
    engine class in the standard ``RolloutController``.
    """
    if config._version == "v2":
        # Imported lazily so the experimental package is only pulled in
        # when the v2 path is actually selected.
        from areal.experimental.inference_service.controller.controller import (
            RolloutControllerV2,
        )

        return RolloutControllerV2(config=config, scheduler=scheduler)
    return RolloutController(cls, config=config, scheduler=scheduler)

def clear_batches(self, shard_ids: list[str]) -> None:
Expand Down
Loading
Loading