Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions miles/ray/rollout/rollout_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
)
from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function
from miles.utils.environ import enable_experimental_rollout_refactor
from miles.utils.event_analyzer import analyzer as event_analyzer
from miles.utils.event_logger import checkpoint as event_logger_checkpoint
from miles.utils.health_monitor import RolloutHealthMonitor
from miles.utils.http_utils import init_http_client
from miles.utils.logging_utils import configure_logger
Expand All @@ -44,6 +46,7 @@ class RolloutManager:
"""The class to run rollout and convert rollout data to training data."""

def __init__(self, args, pg):
event_logger_checkpoint.restore(args)
configure_logger(args, source=RolloutManagerProcessIdentity())

self.pg = pg
Expand Down Expand Up @@ -90,12 +93,13 @@ def __init__(self, args, pg):
monitor = RolloutHealthMonitor(group, args)
monitor.start()
self._health_monitors.append(monitor)
self._ci_fault_injection_pending = self.args.ci_test # Flag for CI fault injection
self._ci_fault_injection_pending = self.args.ci_test and "rollout" in self.args.ft_components

# -------------------------- lifecycle -----------------------------
# TODO: may have a `async def init` here later

def dispose(self):
event_analyzer.run_analysis_from_args(self.args)
if self._metric_checker is not None:
self._metric_checker.dispose()
for monitor in self._health_monitors:
Expand Down Expand Up @@ -177,7 +181,9 @@ async def _get_rollout_data(self, rollout_id):
# -------------------------- checkpointing -----------------------------

def save(self, rollout_id):
self.data_source.save(rollout_id)
if self.args.rollout_global_dataset:
self.data_source.save(rollout_id)
event_logger_checkpoint.snapshot(self.args, rollout_id)

def load(self, rollout_id=None):
self.data_source.load(rollout_id)
Expand Down
42 changes: 42 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def reset_arg(parser, name, **kwargs):
parser.add_argument(name, **kwargs)


_FT_CHOICES = ["rollout", "train"]


def get_miles_extra_args_provider(add_custom_arguments=None):
def add_miles_arguments(parser):
# Ray
Expand Down Expand Up @@ -601,6 +604,14 @@ def add_fault_tolerance_arguments(parser):
default=False,
help="Enable fault tolerance. Use --ft-components to select which components.",
)
parser.add_argument(
"--ft-components",
nargs="+",
default=None,
choices=_FT_CHOICES,
help="FT components to enable (requires --use-fault-tolerance). "
"Choices: rollout, train. Default when omitted: rollout.",
)
parser.add_argument(
"--rollout-health-check-interval",
type=float,
Expand Down Expand Up @@ -2077,9 +2088,40 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]:
return eval_datasets


_FT_DEFAULT_COMPONENTS: list[str] = ["rollout"]


def _resolve_ft_components(args: argparse.Namespace) -> list[str]:
if not args.use_fault_tolerance:
if args.ft_components is not None:
logger.warning("--ft-components is ignored without --use-fault-tolerance")
return []
if args.ft_components is None:
return list(_FT_DEFAULT_COMPONENTS)
return list(args.ft_components)


def miles_validate_args(args):
args.ft_components = _resolve_ft_components(args)
args.eval_datasets = _resolve_eval_datasets(args)

if "train" in args.ft_components:
args.indep_dp = True
args.delay_split_train_data_by_dp = True
args.save_local_weight_checksum = True
args.enable_event_analyzer = True
args.enable_witness = True
args.non_persistent_ckpt_type = "local"
if getattr(args, "non_persistent_local_ckpt_dir", None) is None:
args.non_persistent_local_ckpt_dir = "/tmp/miles_local_ckpt"
# atomic: each rank saves independently, no collective communication.
# fully_parallel needs all_gather_object which hangs after ncclCommAbort in healing.
args.non_persistent_local_ckpt_algo = "atomic"
logger.info(
"train in ft_components. Auto set indep_dp=True, delay_split_train_data_by_dp=True, save_local_weight_checksum=True, enable_event_analyzer=True, enable_witness=True, non_persistent_ckpt_type='local', non_persistent_local_ckpt_algo=%r",
args.non_persistent_local_ckpt_algo,
)

if args.indep_dp:
assert (
args.train_backend == "megatron"
Expand Down
41 changes: 40 additions & 1 deletion tests/fast/utils/test_arguments.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import argparse
import logging
import sys
from types import SimpleNamespace
from unittest.mock import patch

import pytest

from miles.utils.arguments import _maybe_apply_dumper_overrides, get_miles_extra_args_provider
from miles.utils.arguments import _maybe_apply_dumper_overrides, _resolve_ft_components, get_miles_extra_args_provider
from miles.utils.misc import function_registry

PATH_ARGS = ["--rollout-function-path", "--custom-generate-function-path"]
Expand Down Expand Up @@ -141,3 +142,41 @@ def test_recompute_logprobs_via_prefill_flag_is_parsed():
args = parser.parse_args(["--recompute-logprobs-via-prefill"] + REQUIRED_ARGS)

assert args.recompute_logprobs_via_prefill is True


class TestResolveFtComponents:
def test_disabled_with_no_components_returns_empty_without_warning(self, caplog) -> None:
"""use_fault_tolerance off and no ft_components yields an empty list and no warning."""
args = SimpleNamespace(use_fault_tolerance=False, ft_components=None)
with caplog.at_level(logging.WARNING, logger="miles.utils.arguments"):
result = _resolve_ft_components(args)

assert result == []
assert not any("--ft-components is ignored" in record.message for record in caplog.records)

def test_disabled_with_components_returns_empty_and_warns(self, caplog) -> None:
"""use_fault_tolerance off but ft_components set returns empty list and logs an ignore warning."""
args = SimpleNamespace(use_fault_tolerance=False, ft_components=["train"])
with caplog.at_level(logging.WARNING, logger="miles.utils.arguments"):
result = _resolve_ft_components(args)

assert result == []
assert any(
"--ft-components is ignored without --use-fault-tolerance" in record.message for record in caplog.records
)

def test_enabled_with_no_components_returns_default(self) -> None:
"""use_fault_tolerance on with no ft_components falls back to the default ['rollout']."""
args = SimpleNamespace(use_fault_tolerance=True, ft_components=None)
result = _resolve_ft_components(args)

assert result == ["rollout"]

def test_enabled_with_components_returns_distinct_copy(self) -> None:
"""use_fault_tolerance on with ft_components returns an equal but distinct list copy."""
components = ["train", "rollout"]
args = SimpleNamespace(use_fault_tolerance=True, ft_components=components)
result = _resolve_ft_components(args)

assert result == ["train", "rollout"]
assert result is not components
Loading