Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion miles/ray/train/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@
from miles.ray.train.cell import RayTrainCell
from miles.ray.train.cell_monitor import create_trainer_cell_health_checker
from miles.utils.async_utils import AsyncioGatherUtils
from miles.utils.checksum_utils import flatten_inference_engine_checksums
from miles.utils.event_analyzer import analyzer as event_analyzer
from miles.utils.event_logger.logger import get_event_logger, is_event_logger_initialized
from miles.utils.event_logger.models import CellReconfigureEvent, WitnessAllocateIdEvent
from miles.utils.event_logger.models import (
CellReconfigureEvent,
InferenceEngineWeightChecksumEvent,
TrainGroupStepEndEvent,
WitnessAllocateIdEvent,
)
from miles.utils.health_checker import NoopHealthChecker, SimpleHealthCheckerConfig
from miles.utils.indep_dp import IndepDPInfo
from miles.utils.megatron_args_utils import compute_megatron_world_size_except_dp
Expand Down Expand Up @@ -123,6 +130,8 @@ def _create_cell(cell_index: int):
async def train(self, rollout_id: int, rollout_data_pack):
"""Do one rollout training"""

event_analyzer.run_analysis_from_args(self.args)

async def _fn(attempt: int):
witness_info = self._allocate_witness_info(
rollout_id=rollout_id,
Expand All @@ -141,6 +150,12 @@ async def _fn(attempt: int):
)
self._check_train_one_attempt(snapshot_alive_cells, results)

self._log_step_end_event(
rollout_id=rollout_id,
snapshot_alive_cells=snapshot_alive_cells,
results=results,
)

await retry(_fn)

def _allocate_witness_info(self, *, rollout_id: int, attempt: int, sample_indices):
Expand All @@ -162,6 +177,17 @@ def _allocate_witness_info(self, *, rollout_id: int, attempt: int, sample_indice

return witness_info

def _log_step_end_event(self, *, rollout_id: int, snapshot_alive_cells: list, results: list):
if is_event_logger_initialized():
cell_outcomes = {
cell.cell_index: ("error" if isinstance(cell_results, BaseException) else [r for r in cell_results])
for cell, cell_results in zip(snapshot_alive_cells, results, strict=True)
}
get_event_logger().log(
TrainGroupStepEndEvent,
dict(rollout_id=rollout_id, cell_outcomes=cell_outcomes),
)

@staticmethod
def _check_train_one_attempt(snapshot_alive_cells, results):
outcomes = RayTrainGroup._compute_attempt_outcomes(snapshot_alive_cells, results)
Expand Down Expand Up @@ -227,6 +253,21 @@ async def update_weights(self, rollout_id: int | None = None):
# Catch with vanilla retry: cells w/ exceptions are auto marked errored, thus retry will find the next one
await retry(lambda _: self._execute_first_alive("update_weights", info=info))

await self._maybe_log_inference_engine_weight_checksums(rollout_id=rollout_id)

async def _maybe_log_inference_engine_weight_checksums(self, *, rollout_id: int | None) -> None:
if not is_event_logger_initialized():
return
if self.args.debug_train_only or self.args.debug_rollout_only:
return

check_weights_result = await self._rollout_manager.check_weights.remote("checksum")
engine_checksums = flatten_inference_engine_checksums(check_weights_result)
get_event_logger().log(
InferenceEngineWeightChecksumEvent,
dict(rollout_id=rollout_id, engine_checksums=engine_checksums),
)

async def onload(self):
# Catch *without* retry: cells w/ exceptions are auto marked errored, and will not be used
await self._execute_all_alive_and_catch("wake_up")
Expand Down
127 changes: 126 additions & 1 deletion tests/fast/ray/train/test_group.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Callable
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import ray
Expand Down Expand Up @@ -939,3 +939,128 @@ def test_returns_witness_info_when_enabled(self):
assert result is not None
assert len(result.witness_ids) == 3
assert isinstance(result.stale_ids, list)


class TestLogStepEndEvent:
def test_with_normal_and_error_cells(self):
"""Passes correct cell_outcomes to event logger for a mix of normal and errored cells."""
group = _make_group(num_cells=3)

mock_cell_0 = MagicMock()
mock_cell_0.cell_index = 0
mock_cell_1 = MagicMock()
mock_cell_1.cell_index = 1
mock_cell_2 = MagicMock()
mock_cell_2.cell_index = 2

snapshot_alive_cells = [mock_cell_0, mock_cell_1, mock_cell_2]
results = [
[TrainStepOutcome.NORMAL, TrainStepOutcome.NORMAL],
RuntimeError("boom"),
[TrainStepOutcome.NORMAL],
]

with patch("miles.ray.train.group.is_event_logger_initialized", return_value=True), patch(
"miles.ray.train.group.get_event_logger"
) as mock_get_logger:
mock_logger = MagicMock()
mock_get_logger.return_value = mock_logger

group._log_step_end_event(
rollout_id=42,
snapshot_alive_cells=snapshot_alive_cells,
results=results,
)

mock_logger.log.assert_called_once()
args = mock_logger.log.call_args[0]
partial = args[1]
assert partial["rollout_id"] == 42

cell_outcomes = partial["cell_outcomes"]
assert cell_outcomes[0] == [TrainStepOutcome.NORMAL, TrainStepOutcome.NORMAL]
assert cell_outcomes[1] == "error"
assert cell_outcomes[2] == [TrainStepOutcome.NORMAL]


def _checksum_response(engine_checksums: list[dict[str, str]]) -> list:
"""Build a nested servers->groups->engines check_weights('checksum') response."""
engines = [
{"success": True, "message": "ok", "ranks": [{"checksums": cs, "parallelism_info": {"rank": 0}}]}
for cs in engine_checksums
]
return [[engines]]


class TestMaybeLogInferenceEngineWeightChecksums:
async def test_no_event_logger_does_not_call_check_weights(self):
"""Without an initialized event logger, no check_weights request is issued."""
rollout_mgr = MagicMock()
rollout_mgr.check_weights = MagicMock()
group = _make_group(num_cells=1, rollout_manager=rollout_mgr)

with patch("miles.ray.train.group.is_event_logger_initialized", return_value=False):
await group._maybe_log_inference_engine_weight_checksums(rollout_id=0)

rollout_mgr.check_weights.assert_not_called()

async def test_none_rollout_id_logs_event(self):
"""The initial out-of-loop sync (rollout_id=None) still logs an event with rollout_id=None."""
rollout_mgr = MagicMock()
rollout_mgr.check_weights.remote = AsyncMock(return_value=_checksum_response([{"w": "e0"}]))
group = _make_group(num_cells=1, rollout_manager=rollout_mgr)

with patch("miles.ray.train.group.is_event_logger_initialized", return_value=True), patch(
"miles.ray.train.group.get_event_logger"
) as mock_get_logger:
mock_logger = MagicMock()
mock_get_logger.return_value = mock_logger

await group._maybe_log_inference_engine_weight_checksums(rollout_id=None)

mock_logger.log.assert_called_once()
logged = mock_logger.log.call_args.args[1]
assert logged == dict(rollout_id=None, engine_checksums=[{"rank0/w": "e0"}])

async def test_debug_train_only_skips_collection(self):
"""Without real rollout engines (debug_train_only), no check_weights request is issued."""
rollout_mgr = MagicMock()
rollout_mgr.check_weights = MagicMock()
group = _make_group(num_cells=1, rollout_manager=rollout_mgr)
group.args.debug_train_only = True

with patch("miles.ray.train.group.is_event_logger_initialized", return_value=True):
await group._maybe_log_inference_engine_weight_checksums(rollout_id=0)

rollout_mgr.check_weights.assert_not_called()

async def test_debug_rollout_only_skips_collection(self):
"""Without real train engines pushing weights (debug_rollout_only), no check_weights request is issued."""
rollout_mgr = MagicMock()
rollout_mgr.check_weights = MagicMock()
group = _make_group(num_cells=1, rollout_manager=rollout_mgr)
group.args.debug_rollout_only = True

with patch("miles.ray.train.group.is_event_logger_initialized", return_value=True):
await group._maybe_log_inference_engine_weight_checksums(rollout_id=0)

rollout_mgr.check_weights.assert_not_called()

async def test_enabled_logs_one_event_per_rollout(self):
"""With event logger on and real engines, one event holds every engine's checksums."""
rollout_mgr = MagicMock()
rollout_mgr.check_weights.remote = AsyncMock(return_value=_checksum_response([{"w": "e0"}, {"w": "e1"}]))
group = _make_group(num_cells=1, rollout_manager=rollout_mgr)

with patch("miles.ray.train.group.is_event_logger_initialized", return_value=True), patch(
"miles.ray.train.group.get_event_logger"
) as mock_get_logger:
mock_logger = MagicMock()
mock_get_logger.return_value = mock_logger

await group._maybe_log_inference_engine_weight_checksums(rollout_id=3)

rollout_mgr.check_weights.remote.assert_awaited_once_with("checksum")
mock_logger.log.assert_called_once()
logged = mock_logger.log.call_args.args[1]
assert logged == dict(rollout_id=3, engine_checksums=[{"rank0/w": "e0"}, {"rank0/w": "e1"}])
Loading