diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index a88e2760f6..0eba2a6767 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -667,6 +667,12 @@ def add_data_arguments(parser): default=None, help="Number of rollout steps. If not set, we will calculate the number of rollout steps from the dataset size.", ) + parser.add_argument( + "--debug-exit-after-rollout", + type=int, + default=None, + help="Exit training after this many rollouts (for testing checkpoint resume with consistent scheduler params).", + ) parser.add_argument( "--num-epoch", type=int, diff --git a/train.py b/train.py index bd83f51f5a..e319d918ef 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,5 @@ import asyncio +import logging from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS @@ -12,6 +13,8 @@ from miles.utils.process_identity import MainProcessIdentity from miles.utils.tracking_utils import finish_tracking, init_tracking +logger = logging.getLogger(__name__) + async def train(args): configure_logger(args, source=MainProcessIdentity()) @@ -113,6 +116,17 @@ async def save(rollout_id): if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch): await rollout_manager.eval.remote(rollout_id) + if ( + args.debug_exit_after_rollout is not None + and (rollout_id - args.start_rollout_id + 1) >= args.debug_exit_after_rollout + ): + logger.info( + "debug_exit_after_rollout=%d reached at rollout_id=%d, exiting", + args.debug_exit_after_rollout, + rollout_id, + ) + break + await rollout_manager.dispose.remote() diff --git a/train_async.py b/train_async.py index a255327b4a..7e60518b81 100644 --- a/train_async.py +++ b/train_async.py @@ -1,4 +1,5 @@ import asyncio +import logging from miles.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models from miles.utils.arguments import parse_args @@ -10,6 +11,8 @@ from miles.utils.process_identity import MainProcessIdentity from miles.utils.tracking_utils import finish_tracking, init_tracking +logger = logging.getLogger(__name__) + # The framework supports other asynchronous approaches such as fully async (which is shown in examples/full_async). async def train(args): @@ -85,6 +88,17 @@ async def train(args): if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch): await rollout_manager.eval.remote(rollout_id) + if ( + args.debug_exit_after_rollout is not None + and (rollout_id - args.start_rollout_id + 1) >= args.debug_exit_after_rollout + ): + logger.info( + "debug_exit_after_rollout=%d reached at rollout_id=%d, exiting", + args.debug_exit_after_rollout, + rollout_id, + ) + break + await rollout_manager.dispose.remote()