Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,12 @@ def add_data_arguments(parser):
default=None,
help="Number of rollout steps. If not set, we will calculate the number of rollout steps from the dataset size.",
)
parser.add_argument(
"--debug-exit-after-rollout",
type=int,
default=None,
help="Exit training after this many rollouts (for testing checkpoint resume with consistent scheduler params).",
)
parser.add_argument(
"--num-epoch",
type=int,
Expand Down
14 changes: 14 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging

from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS

Expand All @@ -12,6 +13,8 @@
from miles.utils.process_identity import MainProcessIdentity
from miles.utils.tracking_utils import finish_tracking, init_tracking

logger = logging.getLogger(__name__)


async def train(args):
configure_logger(args, source=MainProcessIdentity())
Expand Down Expand Up @@ -113,6 +116,17 @@ async def save(rollout_id):
if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch):
await rollout_manager.eval.remote(rollout_id)

if (
args.debug_exit_after_rollout is not None
and (rollout_id - args.start_rollout_id + 1) >= args.debug_exit_after_rollout
):
logger.info(
"debug_exit_after_rollout=%d reached at rollout_id=%d, exiting",
args.debug_exit_after_rollout,
rollout_id,
)
break

await rollout_manager.dispose.remote()


Expand Down
14 changes: 14 additions & 0 deletions train_async.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging

from miles.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models
from miles.utils.arguments import parse_args
Expand All @@ -10,6 +11,8 @@
from miles.utils.process_identity import MainProcessIdentity
from miles.utils.tracking_utils import finish_tracking, init_tracking

logger = logging.getLogger(__name__)


# The framework also supports other asynchronous approaches, such as the fully async mode shown in examples/full_async.
async def train(args):
Expand Down Expand Up @@ -85,6 +88,17 @@ async def train(args):
if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch):
await rollout_manager.eval.remote(rollout_id)

if (
args.debug_exit_after_rollout is not None
and (rollout_id - args.start_rollout_id + 1) >= args.debug_exit_after_rollout
):
logger.info(
"debug_exit_after_rollout=%d reached at rollout_id=%d, exiting",
args.debug_exit_after_rollout,
rollout_id,
)
break

await rollout_manager.dispose.remote()


Expand Down
Loading