
Commit 4ea9fd9

[tinker-cookbook] rl: avoid hanging in async runs when we run out of data
Previously, async RL runs could hang at shutdown when they ran out of data. This fixes it by ensuring a proper shutdown order: the dataloader loop terminates first, and all data remaining in the queues is drained before the other loops exit.
1 parent 5f5ce26 commit 4ea9fd9
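
For illustration, the shutdown pattern this commit adopts can be sketched as a small standalone asyncio program (hypothetical names such as demo, dataloader, worker, and consumer; this is not the cookbook's actual API): the producer enqueues one Shutdown sentinel per worker once the data is exhausted, each worker drains the queue and decrements a shared counter when it exits, and the last worker to exit enqueues a sentinel for the consumer so no loop hangs waiting on an empty queue.

import asyncio
from dataclasses import dataclass


@dataclass
class Shutdown:
    """Sentinel placed on a queue to tell consumers to stop."""


async def demo(num_items: int = 6, num_workers: int = 3) -> None:
    work_queue: asyncio.Queue = asyncio.Queue(maxsize=num_workers)
    results_queue: asyncio.Queue = asyncio.Queue()
    workers_alive = num_workers
    counter_lock = asyncio.Lock()

    async def dataloader() -> None:
        # Produce every item, then one Shutdown per worker so each worker wakes up.
        for item in range(num_items):
            await work_queue.put(item)
        for _ in range(num_workers):
            await work_queue.put(Shutdown())

    async def worker() -> None:
        nonlocal workers_alive
        while True:
            item = await work_queue.get()
            if isinstance(item, Shutdown):
                break
            await results_queue.put(item * item)  # stand-in for rollout generation
        # The last worker to exit signals the consumer that no more results will arrive.
        async with counter_lock:
            workers_alive -= 1
            if workers_alive == 0:
                results_queue.put_nowait(Shutdown())

    async def consumer() -> None:
        while True:
            result = await results_queue.get()
            if isinstance(result, Shutdown):
                break
            print("trained on", result)

    await asyncio.gather(dataloader(), *(worker() for _ in range(num_workers)), consumer())


if __name__ == "__main__":
    asyncio.run(demo())

Running this prints each processed item and then exits cleanly, which is the behavior the commit restores for async runs that exhaust their dataset.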

File tree

1 file changed: +94 -38 lines changed


tinker_cookbook/rl/train.py

Lines changed: 94 additions & 38 deletions
@@ -7,12 +7,15 @@
 import logging
 import os
 import time
-from typing import Any, Callable, List, Literal, Sequence, Iterator
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Callable, Iterator, List, Literal, Sequence
 
 import chz
 import numpy as np
 import tinker
 import torch
+
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.completers import TinkerTokenCompleter
 from tinker_cookbook.display import colorize_example
@@ -39,9 +42,7 @@
 from tinker_cookbook.tokenizer_utils import Tokenizer
 from tinker_cookbook.utils import logtree, ml_log
 from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
-from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context
-from contextlib import contextmanager
-
+from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init
 
 logger = logging.getLogger(__name__)
 
@@ -354,7 +355,7 @@ async def do_sync_training_with_stream_minibatch(
 ):
     # Samplers will produce trajectory groups asynchronously,
     # and the trainer will consume them as soon as they are ready
-    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
+    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()
     env_group_builders_P = dataset.get_batch(i_batch)
 
     @scope
@@ -393,17 +394,18 @@ async def trajectory_group_worker_task(
     )
 
     # Run multiple optimizer substeps per training iteration
-    (
-        sampling_client,
-        full_batch_metrics,
-    ) = await do_train_step_streaming_and_get_sampling_client(
+    streaming_result = await do_train_step_streaming_and_get_sampling_client(
         cfg,
         i_batch,
         trajectory_groups_queue,
         training_client,
         service_client,
         tokenizer,
     )
+    if streaming_result is None:
+        logger.info("[do_sync_training_with_stream_minibatch] Received shutdown signal")
+        return
+    sampling_client, full_batch_metrics = streaming_result
 
     # Log metrics
     metrics.update(full_batch_metrics)
@@ -428,6 +430,22 @@ class WrappedTrajectoryGroup:
     metrics: dict[str, Any] = chz.field(default_factory=dict)
 
 
+@dataclass
+class Shutdown:
+    pass
+
+
+class AsyncCounter:
+    def __init__(self, start: int = 0):
+        self.value = start
+        self.lock = asyncio.Lock()
+
+    async def decrement_and_get(self) -> int:
+        async with self.lock:
+            self.value -= 1
+            return self.value
+
+
 @scope
 async def do_async_training(
     start_batch: int,
@@ -444,13 +462,12 @@ async def do_async_training(
     """Implements async off-policy training, capped at K steps off policy."""
     assert cfg.async_config is not None
 
-    shutdown_event = asyncio.Event()
     # We will have groups_per_batch worker generating rollouts, so cap the
     # queue size to be groups_per_batch.
-    env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | None](
+    env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | Shutdown](
         maxsize=cfg.async_config.groups_per_batch
     )
-    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
+    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()
 
     # Initial sampling client to use
     path_dict = await checkpoint_utils.save_checkpoint_async(
@@ -461,38 +478,50 @@ async def do_async_training(
         kind="both",
     )
 
+    # Shutdowns can be triggered by the dataloader running out of data,
+    # or the training loop ending early.
+    # If the dataloader is out of data, we want to make sure all remaining samples
+    # are processed before terminating.
+    # If the training loop ends first, we need all loops to terminate immediately.
+    training_loop_shutdown_done_event = asyncio.Event()
+    trajectory_group_worker_alive_counter = AsyncCounter(cfg.async_config.groups_per_batch)
+
     # This will be updated by the training loop
     sampling_client = training_client.create_sampling_client(path_dict["sampler_path"])
     sampling_client_step = start_batch
     sampling_client_updated_event = asyncio.Event()
     sampling_client_updated_event.set()
 
-    @scope
-    def shutdown_loops():
-        """Trigger all loops to shutdown"""
-        shutdown_event.set()
-        assert cfg.async_config is not None
-        for _ in range(cfg.async_config.groups_per_batch):
-            env_group_builders_queue.put_nowait(None)
-        sampling_client_updated_event.set()
-
     @scope
     async def dataloader_loop():
         """Gets the next set of env builders to run"""
         i_batch = start_batch
-        while not shutdown_event.is_set() and i_batch < end_batch:
+        while not training_loop_shutdown_done_event.is_set() and i_batch < end_batch:
            env_group_builders_P = dataset.get_batch(i_batch)
            for env_group_builder in env_group_builders_P:
                await env_group_builders_queue.put(env_group_builder)
            i_batch += 1
 
+        # We are done with the data loader loop, enqueue sentinel values
+        # to allow the trajectory group worker loops to terminate.
+        logger.info("[dataloader_loop] No more data, shutting down trajectory group worker loops")
+        if not training_loop_shutdown_done_event.is_set():
+            assert cfg.async_config is not None
+            for _ in range(cfg.async_config.groups_per_batch):
+                await env_group_builders_queue.put(Shutdown())
+        logger.info("[dataloader_loop] Data loader loop terminated")
+
     @scope
     async def trajectory_group_worker_loop():
         """Generates trajectories for a single env builder"""
-        while not shutdown_event.is_set():
+        while not training_loop_shutdown_done_event.is_set():
             env_group_builder = await env_group_builders_queue.get()
-            if env_group_builder is None:
-                break
+            match env_group_builder:
+                case EnvGroupBuilder():
+                    pass
+                case Shutdown():
+                    logger.info("[trajectory_group_worker_loop] Received shutdown signal")
+                    break
 
             metrics = {}
             t_start = time.time()
@@ -518,6 +547,14 @@ async def trajectory_group_worker_loop():
                     metrics=metrics,
                 )
             )
+        num_alive_workers = await trajectory_group_worker_alive_counter.decrement_and_get()
+        if num_alive_workers == 0:
+            # All workers are done, enqueue a sentinel to terminate the training loop
+            logger.info(
+                "[trajectory_group_worker_loop] Last worker terminated, shutting down training loop"
+            )
+            trajectory_groups_queue.put_nowait(Shutdown())
+        logger.info("[trajectory_group_worker_loop] Trajectory group worker loop terminated")
 
     @scope
     async def training_loop():
@@ -530,9 +567,6 @@ async def training_loop():
         i_batch = start_batch
         wrapped_trajectory_groups = []
         while i_batch < end_batch:
-            wrapped_trajectory_group = await trajectory_groups_queue.get()
-            if wrapped_trajectory_group is None:
-                continue
 
             @scope
             def filter_stale_trajectory_group(
@@ -567,10 +601,7 @@ def filter_stale_trajectory_group(
             nonlocal sampling_client
             nonlocal sampling_client_step
             if cfg.stream_minibatch_config is not None:
-                (
-                    sampling_client,
-                    train_step_metrics,
-                ) = await do_train_step_streaming_and_get_sampling_client(
+                streaming_result = await do_train_step_streaming_and_get_sampling_client(
                     cfg,
                     i_batch,
                     trajectory_groups_queue,
@@ -579,7 +610,21 @@ def filter_stale_trajectory_group(
                     tokenizer,
                     filter_stale_trajectory_group,
                 )
+                if streaming_result is None:
+                    logger.info("[training_loop] Received shutdown signal")
+                    break
+                sampling_client, train_step_metrics = streaming_result
             else:
+                wrapped_trajectory_group = await trajectory_groups_queue.get()
+                match wrapped_trajectory_group:
+                    case WrappedTrajectoryGroup():
+                        pass
+                    case Shutdown():
+                        logger.info("[training_loop] Received shutdown signal")
+                        break
+                    case None:
+                        continue
+
                 if not filter_stale_trajectory_group(wrapped_trajectory_group):
                     continue
 
@@ -618,15 +663,17 @@ def filter_stale_trajectory_group(
             i_batch += 1
             wrapped_trajectory_groups = []
 
-        shutdown_loops()
+        training_loop_shutdown_done_event.set()
+        sampling_client_updated_event.set()
+        logger.info("[training_loop] Training loop terminated")
 
     @scope
     async def evaluation_loop():
         """Runs evals periodically"""
         if len(evaluators) == 0 or cfg.eval_every == 0:
             return
 
-        while not shutdown_event.is_set():
+        while not training_loop_shutdown_done_event.is_set():
             await sampling_client_updated_event.wait()
             sampling_client_updated_event.clear()
 
@@ -643,6 +690,7 @@ async def evaluation_loop():
             metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
             metrics["time/evaluation_loop/total"] = time.time() - t_start
             ml_logger.log_metrics(metrics, step=sampling_client_eval_step)
+        logger.info("[evaluation_loop] Evaluation loop terminated")
 
     await asyncio.gather(
         asyncio.create_task(dataloader_loop(), name="dataloader_loop"),
@@ -787,12 +835,12 @@ async def compute_full_batch_metrics_and_get_sampling_client(
 async def do_train_step_streaming_and_get_sampling_client(
     cfg: Config,
     i_batch: int,
-    trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | None],
+    trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None],
     training_client: tinker.TrainingClient,
     service_client: tinker.ServiceClient,
     tokenizer: Tokenizer,
     trajectory_group_filter: Callable[[WrappedTrajectoryGroup | None], bool] = lambda _: True,
-) -> tuple[tinker.SamplingClient, dict[str, Any]]:
+) -> tuple[tinker.SamplingClient, dict[str, Any]] | None:
     """
     As soon as we have enough trajectories for a minibatch, we will train on them.
     This allows us to overlap sampling and training.
@@ -825,8 +873,16 @@ async def do_train_step_streaming_and_get_sampling_client(
     i_minibatch = 0
     while i_minibatch < cfg.stream_minibatch_config.num_minibatches:
         wrapped_trajectory_group = await trajectory_groups_queue.get()
-        if not trajectory_group_filter(wrapped_trajectory_group):
-            continue
+        match wrapped_trajectory_group:
+            case WrappedTrajectoryGroup():
+                pass
+            case Shutdown():
+                logger.info(
+                    "[do_train_step_streaming_and_get_sampling_client] Received shutdown signal"
+                )
+                return None
+            case None:
+                continue
         wrapped_trajectory_groups.append(wrapped_trajectory_group)
 
         if len(wrapped_trajectory_groups) < groups_per_minibatch:

0 commit comments
