Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,30 @@ def add_fault_tolerance_arguments(parser):
default=0,
help="Initial grace period (in seconds) before starting health checks. This allows time for model compilation and initialization. Increase this value significantly when using deepgemm.",
)
parser.add_argument(
"--control-server-port",
type=int,
default=0,
help="Port for HTTP control server. 0 = disabled.",
)
parser.add_argument(
"--mini-ft-controller-enable",
action="store_true",
default=False,
help="Enable the mini fault-tolerance controller that auto-heals Fatal cells.",
)
parser.add_argument(
"--mini-ft-controller-poll-interval",
type=float,
default=10.0,
help="Interval in seconds between cell health polls.",
)
parser.add_argument(
"--mini-ft-controller-resume-delay",
type=float,
default=10.0,
help="Delay in seconds between suspending and resuming a cell during heal.",
)
SimpleHealthCheckerConfig.add_arguments(parser, prefix="trainer-heartbeat-checker")
return parser

Expand Down Expand Up @@ -2105,6 +2129,9 @@ def miles_validate_args(args):
args.ft_components = _resolve_ft_components(args)
args.eval_datasets = _resolve_eval_datasets(args)

if args.mini_ft_controller_enable and args.control_server_port == 0:
raise ValueError("--mini-ft-controller-enable requires --control-server-port to be set (non-zero)")

if "train" in args.ft_components:
args.indep_dp = True
args.delay_split_train_data_by_dp = True
Expand Down
21 changes: 21 additions & 0 deletions tests/fast/utils/test_mini_ft_controller.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import argparse
import asyncio
import json
from typing import Any
Expand Down Expand Up @@ -645,3 +646,23 @@ async def test_resume_sends_correct_patch(self) -> None:
assert call_args[0][0] == "/api/v1/cells/actor-0"
body = json.loads(call_args[1]["content"])
assert body == {"spec": {"suspend": False}}


class TestArgumentValidation:
def test_requires_control_server_port(self) -> None:
"""mini_ft_controller_enable=True + control_server_port=0 → error."""
from miles.utils.arguments import miles_validate_args

args = argparse.Namespace(
mini_ft_controller_enable=True,
control_server_port=0,
use_fault_tolerance=False,
ft_components=None,
eval_datasets=None,
eval_data=None,
eval_config=None,
eval_prompt_data=None,
)

with pytest.raises(ValueError, match="--mini-ft-controller-enable requires --control-server-port"):
miles_validate_args(args)
12 changes: 12 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from miles.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models
from miles.utils.arguments import parse_args
from miles.utils.async_utils import eager_create_task
from miles.utils.control_server.server import start_control_server
from miles.utils.logging_utils import configure_logger
from miles.utils.mini_ft_controller import maybe_start_mini_ft_controller
from miles.utils.misc import should_run_periodic_action
from miles.utils.process_identity import MainProcessIdentity
from miles.utils.tracking_utils import finish_tracking, init_tracking
Expand All @@ -24,6 +26,16 @@ async def train(args):
# create the actor and critic models
actor_model, critic_model = await create_training_models(args, pgs, rollout_manager)

if args.control_server_port:
start_control_server(
actor_model=actor_model,
rollout_manager=rollout_manager,
port=args.control_server_port,
ft_components=args.ft_components,
)

maybe_start_mini_ft_controller(args)

if args.offload_rollout:
await rollout_manager.onload_weights.remote()

Expand Down
12 changes: 12 additions & 0 deletions train_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from miles.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models
from miles.utils.arguments import parse_args
from miles.utils.async_utils import eager_create_task
from miles.utils.control_server.server import start_control_server
from miles.utils.logging_utils import configure_logger
from miles.utils.mini_ft_controller import maybe_start_mini_ft_controller
from miles.utils.misc import should_run_periodic_action
from miles.utils.process_identity import MainProcessIdentity
from miles.utils.tracking_utils import finish_tracking, init_tracking
Expand All @@ -24,6 +26,16 @@ async def train(args):
# create the actor and critic models
actor_model, critic_model = await create_training_models(args, pgs, rollout_manager)

if args.control_server_port:
start_control_server(
actor_model=actor_model,
rollout_manager=rollout_manager,
port=args.control_server_port,
ft_components=args.ft_components,
)

maybe_start_mini_ft_controller(args)

# always update weight first so that sglang has the loaded weights from training.
await actor_model.update_weights()

Expand Down
Loading