diff --git a/miles/ray/rollout/rollout_manager.py b/miles/ray/rollout/rollout_manager.py index afd5828489..3a9962f845 100644 --- a/miles/ray/rollout/rollout_manager.py +++ b/miles/ray/rollout/rollout_manager.py @@ -23,6 +23,8 @@ ) from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function from miles.utils.environ import enable_experimental_rollout_refactor +from miles.utils.event_analyzer import analyzer as event_analyzer +from miles.utils.event_logger import checkpoint as event_logger_checkpoint from miles.utils.health_monitor import RolloutHealthMonitor from miles.utils.http_utils import init_http_client from miles.utils.logging_utils import configure_logger @@ -44,6 +46,7 @@ class RolloutManager: """The class to run rollout and convert rollout data to training data.""" def __init__(self, args, pg): + event_logger_checkpoint.restore(args) configure_logger(args, source=RolloutManagerProcessIdentity()) self.pg = pg @@ -90,12 +93,13 @@ def __init__(self, args, pg): monitor = RolloutHealthMonitor(group, args) monitor.start() self._health_monitors.append(monitor) - self._ci_fault_injection_pending = self.args.ci_test # Flag for CI fault injection + self._ci_fault_injection_pending = self.args.ci_test and "rollout" in self.args.ft_components # -------------------------- lifecycle ----------------------------- # TODO: may have a `async def init` here later def dispose(self): + event_analyzer.run_analysis_from_args(self.args) if self._metric_checker is not None: self._metric_checker.dispose() for monitor in self._health_monitors: @@ -177,7 +181,9 @@ async def _get_rollout_data(self, rollout_id): # -------------------------- checkpointing ----------------------------- def save(self, rollout_id): - self.data_source.save(rollout_id) + if self.args.rollout_global_dataset: + self.data_source.save(rollout_id) + event_logger_checkpoint.snapshot(self.args, rollout_id) def load(self, rollout_id=None): self.data_source.load(rollout_id) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index b5500ded9e..035ad3114b 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -37,6 +37,9 @@ def reset_arg(parser, name, **kwargs): parser.add_argument(name, **kwargs) +_FT_CHOICES = ["rollout", "train"] + + def get_miles_extra_args_provider(add_custom_arguments=None): def add_miles_arguments(parser): # Ray @@ -601,6 +604,14 @@ def add_fault_tolerance_arguments(parser): default=False, help="Enable fault tolerance. Use --ft-components to select which components.", ) + parser.add_argument( + "--ft-components", + nargs="+", + default=None, + choices=_FT_CHOICES, + help="FT components to enable (requires --use-fault-tolerance). " + "Choices: rollout, train. Default when omitted: rollout.", + ) parser.add_argument( "--rollout-health-check-interval", type=float, @@ -2077,9 +2088,40 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: return eval_datasets +_FT_DEFAULT_COMPONENTS: list[str] = ["rollout"] + + +def _resolve_ft_components(args: argparse.Namespace) -> list[str]: + if not args.use_fault_tolerance: + if args.ft_components is not None: + logger.warning("--ft-components is ignored without --use-fault-tolerance") + return [] + if args.ft_components is None: + return list(_FT_DEFAULT_COMPONENTS) + return list(args.ft_components) + + def miles_validate_args(args): + args.ft_components = _resolve_ft_components(args) args.eval_datasets = _resolve_eval_datasets(args) + if "train" in args.ft_components: + args.indep_dp = True + args.delay_split_train_data_by_dp = True + args.save_local_weight_checksum = True + args.enable_event_analyzer = True + args.enable_witness = True + args.non_persistent_ckpt_type = "local" + if getattr(args, "non_persistent_local_ckpt_dir", None) is None: + args.non_persistent_local_ckpt_dir = "/tmp/miles_local_ckpt" + # atomic: each rank saves independently, no collective communication. + # fully_parallel needs all_gather_object which hangs after ncclCommAbort in healing. + args.non_persistent_local_ckpt_algo = "atomic" + logger.info( + "train in ft_components. Auto set indep_dp=True, delay_split_train_data_by_dp=True, save_local_weight_checksum=True, enable_event_analyzer=True, enable_witness=True, non_persistent_ckpt_type='local', non_persistent_local_ckpt_algo=%r", + args.non_persistent_local_ckpt_algo, + ) + if args.indep_dp: assert ( args.train_backend == "megatron" diff --git a/tests/fast/utils/test_arguments.py b/tests/fast/utils/test_arguments.py index 5fb4345480..c79ae967f9 100644 --- a/tests/fast/utils/test_arguments.py +++ b/tests/fast/utils/test_arguments.py @@ -1,11 +1,12 @@ import argparse +import logging import sys from types import SimpleNamespace from unittest.mock import patch import pytest -from miles.utils.arguments import _maybe_apply_dumper_overrides, get_miles_extra_args_provider +from miles.utils.arguments import _maybe_apply_dumper_overrides, _resolve_ft_components, get_miles_extra_args_provider from miles.utils.misc import function_registry PATH_ARGS = ["--rollout-function-path", "--custom-generate-function-path"] @@ -141,3 +142,41 @@ def test_recompute_logprobs_via_prefill_flag_is_parsed(): args = parser.parse_args(["--recompute-logprobs-via-prefill"] + REQUIRED_ARGS) assert args.recompute_logprobs_via_prefill is True + + +class TestResolveFtComponents: + def test_disabled_with_no_components_returns_empty_without_warning(self, caplog) -> None: + """use_fault_tolerance off and no ft_components yields an empty list and no warning.""" + args = SimpleNamespace(use_fault_tolerance=False, ft_components=None) + with caplog.at_level(logging.WARNING, logger="miles.utils.arguments"): + result = _resolve_ft_components(args) + + assert result == [] + assert not any("--ft-components is ignored" in record.message for record in caplog.records) + + def test_disabled_with_components_returns_empty_and_warns(self, caplog) -> None: + """use_fault_tolerance off but ft_components set returns empty list and logs an ignore warning.""" + args = SimpleNamespace(use_fault_tolerance=False, ft_components=["train"]) + with caplog.at_level(logging.WARNING, logger="miles.utils.arguments"): + result = _resolve_ft_components(args) + + assert result == [] + assert any( + "--ft-components is ignored without --use-fault-tolerance" in record.message for record in caplog.records + ) + + def test_enabled_with_no_components_returns_default(self) -> None: + """use_fault_tolerance on with no ft_components falls back to the default ['rollout'].""" + args = SimpleNamespace(use_fault_tolerance=True, ft_components=None) + result = _resolve_ft_components(args) + + assert result == ["rollout"] + + def test_enabled_with_components_returns_distinct_copy(self) -> None: + """use_fault_tolerance on with ft_components returns an equal but distinct list copy.""" + components = ["train", "rollout"] + args = SimpleNamespace(use_fault_tolerance=True, ft_components=components) + result = _resolve_ft_components(args) + + assert result == ["train", "rollout"] + assert result is not components