Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions miles/ray/train/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import ray

from miles.ray.train.cell_monitor import compute_cell_status
from miles.ray.train.cell_state import (
CellState,
StateAllocatedAlive,
Expand All @@ -14,6 +15,7 @@
StatePending,
StateStopped,
)
from miles.utils.control_server.models import CellStatus
from miles.utils.health_checker import BaseHealthChecker
from miles.utils.indep_dp import IndepDPInfo
from miles.utils.structured_log import log_structured
Expand Down Expand Up @@ -296,6 +298,9 @@ def is_stopped(self) -> bool:
def state_name(self) -> str:
return type(self._state).__name__

def cell_status(self) -> CellStatus:
    """Return the ``CellStatus`` derived from this cell's state and its health checker's status."""
    current_state = self._state
    checker_status = self.health_checker.status
    return compute_cell_status(current_state, checker_status)

@property
def indep_dp_info(self) -> IndepDPInfo:
assert isinstance(self._state, (StateAllocatedAlive, StateAllocatedErrored))
Expand Down
77 changes: 77 additions & 0 deletions miles/ray/train/cell_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import TYPE_CHECKING

from miles.ray.train.cell_state import (
CellState,
StateAllocatedAlive,
StateAllocatedErrored,
StateAllocatedUninitialized,
StatePending,
StateStopped,
)
from miles.utils.control_server.models import CellCondition, CellStatus, TriState
from miles.utils.health_checker import SimpleHealthChecker, SimpleHealthCheckerConfig

if TYPE_CHECKING:
from miles.ray.train.cell import RayTrainCell


def create_trainer_cell_health_checker(
    *,
    cell: "RayTrainCell",
    config: SimpleHealthCheckerConfig,
) -> SimpleHealthChecker:
    """Build the liveness health checker for a single trainer cell.

    Args:
        cell: The cell to probe; its ``is_alive``, ``cell_index`` and ``execute``
            members are used.
        config: Configuration forwarded unchanged to ``SimpleHealthChecker``.

    Returns:
        A ``SimpleHealthChecker`` named ``trainer-cell-<cell_index>`` whose check
        awaits the cell's ``get_heartbeat_status`` call while the cell is alive.
    """

    async def _check() -> None:
        # Health means liveness here, not training progress. The heartbeat RPC is
        # served on a dedicated concurrency group, so it still returns while the
        # training thread is blocked in a (legitimately waiting) cross-cell
        # collective. Getting a result back proves the process is alive; a
        # RayActorError or an RPC timeout proves it is not.
        # Cells that are not alive are not probed at all.
        if cell.is_alive:
            await cell.execute("get_heartbeat_status", mark_errored_on_failure=False)

    checker_name = f"trainer-cell-{cell.cell_index}"
    return SimpleHealthChecker(name=checker_name, check_fn=_check, config=config)


def compute_cell_status(state: CellState, health_checker_status: TriState) -> CellStatus:
match state:
case StateAllocatedAlive():
match health_checker_status:
case TriState.FALSE:
healthy = CellCondition.healthy(TriState.FALSE, reason="HealthCheckFailed")
case TriState.UNKNOWN:
healthy = CellCondition.healthy(TriState.UNKNOWN, reason="HealthCheckUnknown")
case TriState.TRUE:
healthy = CellCondition.healthy(TriState.TRUE)
return CellStatus(phase="Running", conditions=[CellCondition.allocated(TriState.TRUE), healthy])

case StateAllocatedUninitialized():
return CellStatus(
phase="Running",
conditions=[
CellCondition.allocated(TriState.TRUE),
CellCondition.healthy(TriState.TRUE),
],
)

case StateAllocatedErrored():
return CellStatus(
phase="Running",
conditions=[
CellCondition.allocated(TriState.TRUE),
CellCondition.healthy(TriState.FALSE, reason="ExecutionErrored"),
],
)

case StatePending():
return CellStatus(phase="Pending", conditions=[CellCondition.allocated(TriState.FALSE)])

case StateStopped():
return CellStatus(phase="Suspended", conditions=[CellCondition.allocated(TriState.FALSE)])

case _:
raise NotImplementedError(f"Unknown state: {state}")
13 changes: 12 additions & 1 deletion miles/ray/train/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from miles.backends.megatron_utils.types import TrainStepOutcome
from miles.ray.train.actor_factory import allocate_gpus_for_actor
from miles.ray.train.cell import RayTrainCell
from miles.ray.train.cell_monitor import create_trainer_cell_health_checker
from miles.utils.async_utils import AsyncioGatherUtils
from miles.utils.event_logger.logger import get_event_logger, is_event_logger_initialized
from miles.utils.event_logger.models import CellReconfigureEvent, WitnessAllocateIdEvent
from miles.utils.health_checker import NoopHealthChecker
from miles.utils.health_checker import NoopHealthChecker, SimpleHealthCheckerConfig
from miles.utils.indep_dp import IndepDPInfo
from miles.utils.megatron_args_utils import compute_megatron_world_size_except_dp
from miles.utils.retry_utils import retry
Expand Down Expand Up @@ -75,6 +76,10 @@ def __init__(
else:
self._indep_dp_store, indep_dp_store_addr = None, None

health_checker_config = (
SimpleHealthCheckerConfig.from_args(args, prefix="trainer_heartbeat_checker") if num_cells > 1 else None
)

def _create_cell(cell_index: int):
cell_pg = _slice_pg(pg, start=cell_index * gpus_per_cell, end=(cell_index + 1) * gpus_per_cell)

Expand All @@ -97,6 +102,12 @@ def _create_cell(cell_index: int):
health_checker=NoopHealthChecker(),
)

if health_checker_config is not None:
cell.health_checker = create_trainer_cell_health_checker(
cell=cell,
config=health_checker_config,
)

return cell

self._cells: list[RayTrainCell] = [_create_cell(cell_index) for cell_index in range(num_cells)]
Expand Down
2 changes: 2 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from miles.utils.chat_template_utils.tito_tokenizer import TITOTokenizerType
from miles.utils.environ import enable_experimental_rollout_refactor
from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list
from miles.utils.health_checker import SimpleHealthCheckerConfig
from miles.utils.hf_config import is_dsa, load_hf_config
from miles.utils.logging_utils import configure_logger_raw
from miles.utils.megatron_args_utils import compute_megatron_world_size_except_dp
Expand Down Expand Up @@ -618,6 +619,7 @@ def add_fault_tolerance_arguments(parser):
default=0,
help="Initial grace period (in seconds) before starting health checks. This allows time for model compilation and initialization. Increase this value significantly when using deepgemm.",
)
SimpleHealthCheckerConfig.add_arguments(parser, prefix="trainer-heartbeat-checker")
return parser

# data
Expand Down
151 changes: 151 additions & 0 deletions tests/fast/ray/train/test_cell_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from unittest.mock import AsyncMock, MagicMock

import pytest
import ray

from miles.ray.train.cell_monitor import compute_cell_status, create_trainer_cell_health_checker
from miles.ray.train.cell_state import (
StateAllocatedAlive,
StateAllocatedErrored,
StateAllocatedUninitialized,
StatePending,
StateStopped,
)
from miles.utils.control_server.models import TriState
from miles.utils.health_checker import SimpleHealthCheckerConfig
from miles.utils.indep_dp import IndepDPInfo


def _make_actor_handle_mock() -> MagicMock:
    """Return a ``MagicMock`` whose spec is ``ray.actor.ActorHandle``."""
    handle_spec = ray.actor.ActorHandle
    return MagicMock(spec=handle_spec)


def _make_indep_dp_info() -> IndepDPInfo:
    """Build an ``IndepDPInfo`` for a lone cell: index 0 of 1, alive rank 0 of 1, quorum 1."""
    fields = {
        "cell_index": 0,
        "num_cells": 1,
        "alive_rank": 0,
        "alive_size": 1,
        "quorum_id": 1,
        "alive_cell_indices": [0],
    }
    return IndepDPInfo(**fields)


def _make_alive_state() -> StateAllocatedAlive:
    """Build a ``StateAllocatedAlive`` holding one mocked actor handle and the default DP info."""
    handles = [_make_actor_handle_mock()]
    dp_info = _make_indep_dp_info()
    return StateAllocatedAlive(actor_handles=handles, indep_dp_info=dp_info)


def _find_condition(status, type_: str):
matches = [c for c in status.conditions if c.type == type_]
assert len(matches) == 1, f"expected exactly one {type_!r} condition, got {len(matches)}"
return matches[0]


class TestComputeCellStatusAlive:
    """For an alive cell, the Healthy condition mirrors the health checker's tri-state."""

    def test_health_true_reports_healthy_true(self):
        status = compute_cell_status(_make_alive_state(), TriState.TRUE)

        assert status.phase == "Running"
        condition = _find_condition(status, "Healthy")
        assert condition.status == TriState.TRUE
        assert condition.reason is None

    def test_health_false_reports_healthy_false_with_failed_reason(self):
        status = compute_cell_status(_make_alive_state(), TriState.FALSE)

        condition = _find_condition(status, "Healthy")
        assert condition.status == TriState.FALSE
        assert condition.reason == "HealthCheckFailed"

    def test_health_unknown_reports_healthy_unknown_not_translated_to_true(self):
        """Regression: a paused health checker reports UNKNOWN, which used to be
        silently translated to Healthy=TRUE and so hid the transient state from
        observers. UNKNOWN must be surfaced as UNKNOWN."""
        status = compute_cell_status(_make_alive_state(), TriState.UNKNOWN)

        condition = _find_condition(status, "Healthy")
        assert condition.status == TriState.UNKNOWN
        assert condition.reason == "HealthCheckUnknown"


class TestComputeCellStatusOtherStates:
    """``compute_cell_status`` for every state other than ``StateAllocatedAlive``."""

    # Every health-checker value, used where the result must not depend on it.
    _ALL_HEALTH_STATUSES = [TriState.TRUE, TriState.FALSE, TriState.UNKNOWN]

    @pytest.mark.parametrize("health_status", _ALL_HEALTH_STATUSES)
    def test_uninitialized_ignores_health_checker(self, health_status: TriState):
        handles = [_make_actor_handle_mock()]
        uninitialized = StateAllocatedUninitialized(actor_handles=handles)

        status = compute_cell_status(uninitialized, health_status)

        assert status.phase == "Running"
        condition = _find_condition(status, "Healthy")
        assert condition.status == TriState.TRUE

    @pytest.mark.parametrize("health_status", _ALL_HEALTH_STATUSES)
    def test_errored_always_reports_unhealthy(self, health_status: TriState):
        handles = [_make_actor_handle_mock()]
        errored = StateAllocatedErrored(actor_handles=handles, indep_dp_info=_make_indep_dp_info())

        status = compute_cell_status(errored, health_status)

        condition = _find_condition(status, "Healthy")
        assert condition.status == TriState.FALSE
        assert condition.reason == "ExecutionErrored"

    def test_pending_reports_allocated_false_no_healthy_condition(self):
        status = compute_cell_status(StatePending(), TriState.UNKNOWN)

        assert status.phase == "Pending"
        allocated = _find_condition(status, "Allocated")
        assert allocated.status == TriState.FALSE
        assert not any(cond.type == "Healthy" for cond in status.conditions)

    def test_stopped_reports_suspended_phase_no_healthy_condition(self):
        status = compute_cell_status(StateStopped(), TriState.UNKNOWN)

        assert status.phase == "Suspended"
        assert not any(cond.type == "Healthy" for cond in status.conditions)


def _make_cell_mock(*, is_alive: bool, execute: AsyncMock) -> MagicMock:
cell = MagicMock()
cell.is_alive = is_alive
cell.cell_index = 0
cell.execute = execute
return cell


class TestTrainerCellHealthCheckLiveness:
    """Cell health means process liveness (RPC reachability), not training progress:
    the check must pass whenever the heartbeat RPC returns and fail only when the
    actor is dead or unresponsive."""

    @staticmethod
    def _checker_for(*, is_alive: bool, execute: AsyncMock):
        """Build a checker around a mocked cell using a default checker config."""
        cell = _make_cell_mock(is_alive=is_alive, execute=execute)
        return create_trainer_cell_health_checker(cell=cell, config=SimpleHealthCheckerConfig())

    @pytest.mark.asyncio
    async def test_rpc_returns_means_healthy_regardless_of_progress(self):
        """Any returned heartbeat, even a stale one, proves liveness, so the check passes."""
        stale_heartbeat = MagicMock(last_active_timestamp=0.0, bump_count=1)
        rpc = AsyncMock(return_value=[stale_heartbeat])
        checker = self._checker_for(is_alive=True, execute=rpc)

        await checker._check_fn()

        rpc.assert_awaited_once_with("get_heartbeat_status", mark_errored_on_failure=False)

    @pytest.mark.asyncio
    async def test_rpc_error_propagates_as_unhealthy(self):
        """With a dead actor the heartbeat RPC raises; the check lets the error
        propagate so the checker reports unhealthy."""
        rpc = AsyncMock(side_effect=ray.exceptions.RayActorError())
        checker = self._checker_for(is_alive=True, execute=rpc)

        with pytest.raises(ray.exceptions.RayActorError):
            await checker._check_fn()

    @pytest.mark.asyncio
    async def test_not_alive_cell_skips_rpc(self):
        """A cell that is not alive is neither probed nor reported unhealthy."""
        rpc = AsyncMock()
        checker = self._checker_for(is_alive=False, execute=rpc)

        await checker._check_fn()

        rpc.assert_not_awaited()
51 changes: 51 additions & 0 deletions tests/fast/ray/train/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,57 @@ async def test_healing_failure_marks_pending_cell_errored_keeps_alive(self):
assert group._cells[2].is_errored


class TestHeartbeatMonitor:
    """Per-cell heartbeat checkers installed on a multi-cell group."""

    async def test_heartbeat_normal_does_not_mark_errored(self):
        """Running every cell's check on a healthy group leaves all cells alive."""
        group = await _make_alive_group(num_cells=2)
        cells = group._cells

        for cell in cells:
            await cell.health_checker._check_fn()

        for cell in cells:
            assert cell.is_alive

    async def test_heartbeat_stale_timestamp_does_not_mark_errored(self):
        """A stale heartbeat timestamp on its own keeps the cell healthy. Cell health
        is liveness, not training progress: a cell legitimately blocked in a
        cross-cell collective stops bumping its heartbeat, yet must not be reported
        unhealthy while the heartbeat RPC still returns."""
        group = await _make_alive_group(num_cells=2)
        cells = group._cells

        # Push cell 1's last-active timestamp back to the epoch (maximally stale);
        # the liveness check has to ignore staleness while the RPC keeps returning.
        for handle in cells[1]._get_actor_handles():
            pending = handle.set_last_active_timestamp.remote(0.0)
            ray.get(pending)

        # Neither check raises, since a returned heartbeat proves the process is
        # alive, and both cells remain alive despite cell 1's stale timestamp.
        for index in (1, 0):
            await cells[index].health_checker._check_fn()
        for cell in cells:
            assert cell.is_alive

    async def test_heartbeat_timeout_marks_errored(self):
        """When the heartbeat call fails (actor unresponsive), the cell's check
        raises the injected failure."""
        group = await _make_alive_group(num_cells=2)
        failing_cell = group._cells[0]

        for handle in failing_cell._get_actor_handles():
            pending = handle.set_heartbeat_fail.remote(True)
            ray.get(pending)

        with pytest.raises(RuntimeError, match="Injected heartbeat failure"):
            await failing_cell.health_checker._check_fn()

    async def test_pause_resume(self):
        """Pausing and resuming each cell's checker toggles its ``_paused`` flag."""
        group = await _make_alive_group(num_cells=2)
        cells = group._cells

        for cell in cells:
            cell.health_checker.pause()
        for cell in cells:
            assert cell.health_checker._paused

        for cell in cells:
            cell.health_checker.resume()
        for cell in cells:
            assert not cell.health_checker._paused


def _make_mock_cells(n: int) -> list[MagicMock]:
return [MagicMock(health_checker=MagicMock()) for _ in range(n)]

Expand Down
Loading