From f8190b4e629f8ed7cd6d5369727a4b09b3e51400 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 22 Jun 2026 18:15:19 +0800 Subject: [PATCH] Log train-group step-end and analysis events Adds observability to RayTrainGroup: run the event analyzer at the start of each train() rollout, emit a TrainGroupStepEndEvent with per-cell outcomes after each attempt, and emit an InferenceEngineWeightChecksumEvent after update_weights (collected via rollout_manager.check_weights). Includes the matching unit tests. The witness-id and cell-reconfigure events are added with their own features. --- miles/ray/train/group.py | 43 +++++++++- tests/fast/ray/train/test_group.py | 127 ++++++++++++++++++++++++++++- 2 files changed, 168 insertions(+), 2 deletions(-) diff --git a/miles/ray/train/group.py b/miles/ray/train/group.py index 0ceac8e224..87d1944421 100644 --- a/miles/ray/train/group.py +++ b/miles/ray/train/group.py @@ -13,8 +13,15 @@ from miles.ray.train.cell import RayTrainCell from miles.ray.train.cell_monitor import create_trainer_cell_health_checker from miles.utils.async_utils import AsyncioGatherUtils +from miles.utils.checksum_utils import flatten_inference_engine_checksums +from miles.utils.event_analyzer import analyzer as event_analyzer from miles.utils.event_logger.logger import get_event_logger, is_event_logger_initialized -from miles.utils.event_logger.models import CellReconfigureEvent, WitnessAllocateIdEvent +from miles.utils.event_logger.models import ( + CellReconfigureEvent, + InferenceEngineWeightChecksumEvent, + TrainGroupStepEndEvent, + WitnessAllocateIdEvent, +) from miles.utils.health_checker import NoopHealthChecker, SimpleHealthCheckerConfig from miles.utils.indep_dp import IndepDPInfo from miles.utils.megatron_args_utils import compute_megatron_world_size_except_dp @@ -123,6 +130,8 @@ def _create_cell(cell_index: int): async def train(self, rollout_id: int, rollout_data_pack): """Do one rollout training""" + event_analyzer.run_analysis_from_args(self.args) + async def _fn(attempt: int): witness_info = self._allocate_witness_info( rollout_id=rollout_id, @@ -141,6 +150,12 @@ async def _fn(attempt: int): ) self._check_train_one_attempt(snapshot_alive_cells, results) + self._log_step_end_event( + rollout_id=rollout_id, + snapshot_alive_cells=snapshot_alive_cells, + results=results, + ) + await retry(_fn) def _allocate_witness_info(self, *, rollout_id: int, attempt: int, sample_indices): @@ -162,6 +177,17 @@ def _allocate_witness_info(self, *, rollout_id: int, attempt: int, sample_indice return witness_info + def _log_step_end_event(self, *, rollout_id: int, snapshot_alive_cells: list, results: list): + if is_event_logger_initialized(): + cell_outcomes = { + cell.cell_index: ("error" if isinstance(cell_results, BaseException) else [r for r in cell_results]) + for cell, cell_results in zip(snapshot_alive_cells, results, strict=True) + } + get_event_logger().log( + TrainGroupStepEndEvent, + dict(rollout_id=rollout_id, cell_outcomes=cell_outcomes), + ) + @staticmethod def _check_train_one_attempt(snapshot_alive_cells, results): outcomes = RayTrainGroup._compute_attempt_outcomes(snapshot_alive_cells, results) @@ -227,6 +253,21 @@ async def update_weights(self, rollout_id: int | None = None): # Catch with vanilla retry: cells w/ exceptions are auto marked errored, thus retry will find the next one await retry(lambda _: self._execute_first_alive("update_weights", info=info)) + await self._maybe_log_inference_engine_weight_checksums(rollout_id=rollout_id) + + async def _maybe_log_inference_engine_weight_checksums(self, *, rollout_id: int | None) -> None: + if not is_event_logger_initialized(): + return + if self.args.debug_train_only or self.args.debug_rollout_only: + return + + check_weights_result = await self._rollout_manager.check_weights.remote("checksum") + engine_checksums = flatten_inference_engine_checksums(check_weights_result) + get_event_logger().log( + InferenceEngineWeightChecksumEvent, + dict(rollout_id=rollout_id, engine_checksums=engine_checksums), + ) + async def onload(self): # Catch *without* retry: cells w/ exceptions are auto marked errored, and will not be used await self._execute_all_alive_and_catch("wake_up") diff --git a/tests/fast/ray/train/test_group.py b/tests/fast/ray/train/test_group.py index 07b6ef7b2a..8c51493fb9 100644 --- a/tests/fast/ray/train/test_group.py +++ b/tests/fast/ray/train/test_group.py @@ -1,7 +1,7 @@ from collections.abc import Callable from pathlib import Path from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest import ray @@ -939,3 +939,128 @@ def test_returns_witness_info_when_enabled(self): assert result is not None assert len(result.witness_ids) == 3 assert isinstance(result.stale_ids, list) + + +class TestLogStepEndEvent: + def test_with_normal_and_error_cells(self): + """Passes correct cell_outcomes to event logger for a mix of normal and errored cells.""" + group = _make_group(num_cells=3) + + mock_cell_0 = MagicMock() + mock_cell_0.cell_index = 0 + mock_cell_1 = MagicMock() + mock_cell_1.cell_index = 1 + mock_cell_2 = MagicMock() + mock_cell_2.cell_index = 2 + + snapshot_alive_cells = [mock_cell_0, mock_cell_1, mock_cell_2] + results = [ + [TrainStepOutcome.NORMAL, TrainStepOutcome.NORMAL], + RuntimeError("boom"), + [TrainStepOutcome.NORMAL], + ] + + with patch("miles.ray.train.group.is_event_logger_initialized", return_value=True), patch( + "miles.ray.train.group.get_event_logger" + ) as mock_get_logger: + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + + group._log_step_end_event( + rollout_id=42, + snapshot_alive_cells=snapshot_alive_cells, + results=results, + ) + + mock_logger.log.assert_called_once() + args = mock_logger.log.call_args[0] + partial = args[1] + assert partial["rollout_id"] == 42 + + cell_outcomes = partial["cell_outcomes"] + assert cell_outcomes[0] == [TrainStepOutcome.NORMAL, TrainStepOutcome.NORMAL] + assert cell_outcomes[1] == "error" + assert cell_outcomes[2] == [TrainStepOutcome.NORMAL] + + +def _checksum_response(engine_checksums: list[dict[str, str]]) -> list: + """Build a nested servers->groups->engines check_weights('checksum') response.""" + engines = [ + {"success": True, "message": "ok", "ranks": [{"checksums": cs, "parallelism_info": {"rank": 0}}]} + for cs in engine_checksums + ] + return [[engines]] + + +class TestMaybeLogInferenceEngineWeightChecksums: + async def test_no_event_logger_does_not_call_check_weights(self): + """Without an initialized event logger, no check_weights request is issued.""" + rollout_mgr = MagicMock() + rollout_mgr.check_weights = MagicMock() + group = _make_group(num_cells=1, rollout_manager=rollout_mgr) + + with patch("miles.ray.train.group.is_event_logger_initialized", return_value=False): + await group._maybe_log_inference_engine_weight_checksums(rollout_id=0) + + rollout_mgr.check_weights.assert_not_called() + + async def test_none_rollout_id_logs_event(self): + """The initial out-of-loop sync (rollout_id=None) still logs an event with rollout_id=None.""" + rollout_mgr = MagicMock() + rollout_mgr.check_weights.remote = AsyncMock(return_value=_checksum_response([{"w": "e0"}])) + group = _make_group(num_cells=1, rollout_manager=rollout_mgr) + + with patch("miles.ray.train.group.is_event_logger_initialized", return_value=True), patch( + "miles.ray.train.group.get_event_logger" + ) as mock_get_logger: + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + + await group._maybe_log_inference_engine_weight_checksums(rollout_id=None) + + mock_logger.log.assert_called_once() + logged = mock_logger.log.call_args.args[1] + assert logged == dict(rollout_id=None, engine_checksums=[{"rank0/w": "e0"}]) + + async def test_debug_train_only_skips_collection(self): + """Without real rollout engines (debug_train_only), no check_weights request is issued.""" + rollout_mgr = MagicMock() + rollout_mgr.check_weights = MagicMock() + group = _make_group(num_cells=1, rollout_manager=rollout_mgr) + group.args.debug_train_only = True + + with patch("miles.ray.train.group.is_event_logger_initialized", return_value=True): + await group._maybe_log_inference_engine_weight_checksums(rollout_id=0) + + rollout_mgr.check_weights.assert_not_called() + + async def test_debug_rollout_only_skips_collection(self): + """Without real train engines pushing weights (debug_rollout_only), no check_weights request is issued.""" + rollout_mgr = MagicMock() + rollout_mgr.check_weights = MagicMock() + group = _make_group(num_cells=1, rollout_manager=rollout_mgr) + group.args.debug_rollout_only = True + + with patch("miles.ray.train.group.is_event_logger_initialized", return_value=True): + await group._maybe_log_inference_engine_weight_checksums(rollout_id=0) + + rollout_mgr.check_weights.assert_not_called() + + async def test_enabled_logs_one_event_per_rollout(self): + """With event logger on and real engines, one event holds every engine's checksums.""" + rollout_mgr = MagicMock() + rollout_mgr.check_weights.remote = AsyncMock(return_value=_checksum_response([{"w": "e0"}, {"w": "e1"}])) + group = _make_group(num_cells=1, rollout_manager=rollout_mgr) + + with patch("miles.ray.train.group.is_event_logger_initialized", return_value=True), patch( + "miles.ray.train.group.get_event_logger" + ) as mock_get_logger: + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + + await group._maybe_log_inference_engine_weight_checksums(rollout_id=3) + + rollout_mgr.check_weights.remote.assert_awaited_once_with("checksum") + mock_logger.log.assert_called_once() + logged = mock_logger.log.call_args.args[1] + assert logged == dict(rollout_id=3, engine_checksums=[{"rank0/w": "e0"}, {"rank0/w": "e1"}])