 import logging
 import os
 import time
-from typing import Any, Callable, List, Literal, Sequence, Iterator
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Callable, Iterator, List, Literal, Sequence
 
 import chz
 import numpy as np
 import tinker
 import torch
+
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.completers import TinkerTokenCompleter
 from tinker_cookbook.display import colorize_example
 from tinker_cookbook.tokenizer_utils import Tokenizer
 from tinker_cookbook.utils import logtree, ml_log
 from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
-from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context
-from contextlib import contextmanager
-
+from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init
 
 logger = logging.getLogger(__name__)
 
@@ -354,7 +355,7 @@ async def do_sync_training_with_stream_minibatch(
     ):
         # Samplers will produce trajectory groups asynchronously,
         # and the trainer will consume them as soon as they are ready
-        trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
+        trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()
         env_group_builders_P = dataset.get_batch(i_batch)
 
         @scope
@@ -393,17 +394,18 @@ async def trajectory_group_worker_task(
         )
 
         # Run multiple optimizer substeps per training iteration
-        (
-            sampling_client,
-            full_batch_metrics,
-        ) = await do_train_step_streaming_and_get_sampling_client(
+        streaming_result = await do_train_step_streaming_and_get_sampling_client(
             cfg,
             i_batch,
             trajectory_groups_queue,
             training_client,
             service_client,
             tokenizer,
         )
+        if streaming_result is None:
+            logger.info("[do_sync_training_with_stream_minibatch] Received shutdown signal")
+            return
+        sampling_client, full_batch_metrics = streaming_result
 
         # Log metrics
         metrics.update(full_batch_metrics)
@@ -428,6 +430,22 @@ class WrappedTrajectoryGroup:
     metrics: dict[str, Any] = chz.field(default_factory=dict)
 
 
+@dataclass
+class Shutdown:
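+    """Sentinel enqueued on a queue to tell the consuming loop to shut down."""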
+    pass
+
+
+class AsyncCounter:
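+    """Counter guarded by an asyncio.Lock so concurrent tasks can safely decrement it and read the result."""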
+    def __init__(self, start: int = 0):
+        self.value = start
+        self.lock = asyncio.Lock()
+
+    async def decrement_and_get(self) -> int:
+        async with self.lock:
+            self.value -= 1
+            return self.value
+
+
 @scope
 async def do_async_training(
     start_batch: int,
@@ -444,13 +462,12 @@ async def do_async_training(
     """Implements async off-policy training, capped at K steps off policy."""
     assert cfg.async_config is not None
 
-    shutdown_event = asyncio.Event()
     # We will have groups_per_batch workers generating rollouts, so cap the
     # queue size to be groups_per_batch.
-    env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | None](
+    env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | Shutdown](
         maxsize=cfg.async_config.groups_per_batch
     )
-    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
+    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()
 
     # Initial sampling client to use
     path_dict = await checkpoint_utils.save_checkpoint_async(
@@ -461,38 +478,50 @@ async def do_async_training(
         kind="both",
     )
 
+    # Shutdowns can be triggered by the dataloader running out of data,
+    # or by the training loop ending early.
+    # If the dataloader is out of data, we want to make sure all remaining samples
+    # are processed before terminating.
+    # If the training loop ends first, we need all loops to terminate immediately.
+    training_loop_shutdown_done_event = asyncio.Event()
+    trajectory_group_worker_alive_counter = AsyncCounter(cfg.async_config.groups_per_batch)
+
     # This will be updated by the training loop
     sampling_client = training_client.create_sampling_client(path_dict["sampler_path"])
     sampling_client_step = start_batch
     sampling_client_updated_event = asyncio.Event()
     sampling_client_updated_event.set()
 
-    @scope
-    def shutdown_loops():
-        """Trigger all loops to shutdown"""
-        shutdown_event.set()
-        assert cfg.async_config is not None
-        for _ in range(cfg.async_config.groups_per_batch):
-            env_group_builders_queue.put_nowait(None)
-        sampling_client_updated_event.set()
-
     @scope
     async def dataloader_loop():
         """Gets the next set of env builders to run"""
         i_batch = start_batch
-        while not shutdown_event.is_set() and i_batch < end_batch:
+        while not training_loop_shutdown_done_event.is_set() and i_batch < end_batch:
             env_group_builders_P = dataset.get_batch(i_batch)
             for env_group_builder in env_group_builders_P:
                 await env_group_builders_queue.put(env_group_builder)
             i_batch += 1
 
+        # We are done with the data loader loop, enqueue sentinel values
+        # to allow the trajectory group worker loops to terminate.
+        logger.info("[dataloader_loop] No more data, shutting down trajectory group worker loops")
+        if not training_loop_shutdown_done_event.is_set():
+            assert cfg.async_config is not None
+            for _ in range(cfg.async_config.groups_per_batch):
+                await env_group_builders_queue.put(Shutdown())
+        logger.info("[dataloader_loop] Data loader loop terminated")
+
     @scope
     async def trajectory_group_worker_loop():
         """Generates trajectories for a single env builder"""
-        while not shutdown_event.is_set():
+        while not training_loop_shutdown_done_event.is_set():
            env_group_builder = await env_group_builders_queue.get()
-            if env_group_builder is None:
-                break
+            match env_group_builder:
+                case EnvGroupBuilder():
+                    pass
+                case Shutdown():
+                    logger.info("[trajectory_group_worker_loop] Received shutdown signal")
+                    break
 
             metrics = {}
             t_start = time.time()
@@ -518,6 +547,14 @@ async def trajectory_group_worker_loop():
                     metrics=metrics,
                 )
             )
+        num_alive_workers = await trajectory_group_worker_alive_counter.decrement_and_get()
+        if num_alive_workers == 0:
+            # All workers are done, enqueue a sentinel to terminate the training loop
+            logger.info(
+                "[trajectory_group_worker_loop] Last worker terminated, shutting down training loop"
+            )
+            trajectory_groups_queue.put_nowait(Shutdown())
+        logger.info("[trajectory_group_worker_loop] Trajectory group worker loop terminated")
 
     @scope
     async def training_loop():
@@ -530,9 +567,6 @@ async def training_loop():
         i_batch = start_batch
         wrapped_trajectory_groups = []
         while i_batch < end_batch:
-            wrapped_trajectory_group = await trajectory_groups_queue.get()
-            if wrapped_trajectory_group is None:
-                continue
 
             @scope
             def filter_stale_trajectory_group(
@@ -567,10 +601,7 @@ def filter_stale_trajectory_group(
             nonlocal sampling_client
             nonlocal sampling_client_step
             if cfg.stream_minibatch_config is not None:
-                (
-                    sampling_client,
-                    train_step_metrics,
-                ) = await do_train_step_streaming_and_get_sampling_client(
+                streaming_result = await do_train_step_streaming_and_get_sampling_client(
                     cfg,
                     i_batch,
                     trajectory_groups_queue,
@@ -579,7 +610,21 @@ def filter_stale_trajectory_group(
                     tokenizer,
                     filter_stale_trajectory_group,
                 )
+                if streaming_result is None:
+                    logger.info("[training_loop] Received shutdown signal")
+                    break
+                sampling_client, train_step_metrics = streaming_result
             else:
+                wrapped_trajectory_group = await trajectory_groups_queue.get()
+                match wrapped_trajectory_group:
+                    case WrappedTrajectoryGroup():
+                        pass
+                    case Shutdown():
+                        logger.info("[training_loop] Received shutdown signal")
+                        break
+                    case None:
+                        continue
+
                 if not filter_stale_trajectory_group(wrapped_trajectory_group):
                     continue
 
@@ -618,15 +663,17 @@ def filter_stale_trajectory_group(
             i_batch += 1
             wrapped_trajectory_groups = []
 
-        shutdown_loops()
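+        # Setting the sampling-client event also wakes the evaluation loop so it can observe the shutdown.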
+        training_loop_shutdown_done_event.set()
+        sampling_client_updated_event.set()
+        logger.info("[training_loop] Training loop terminated")
 
     @scope
     async def evaluation_loop():
         """Runs evals periodically"""
         if len(evaluators) == 0 or cfg.eval_every == 0:
             return
 
-        while not shutdown_event.is_set():
+        while not training_loop_shutdown_done_event.is_set():
             await sampling_client_updated_event.wait()
             sampling_client_updated_event.clear()
 
@@ -643,6 +690,7 @@ async def evaluation_loop():
             metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
             metrics["time/evaluation_loop/total"] = time.time() - t_start
             ml_logger.log_metrics(metrics, step=sampling_client_eval_step)
+        logger.info("[evaluation_loop] Evaluation loop terminated")
 
     await asyncio.gather(
         asyncio.create_task(dataloader_loop(), name="dataloader_loop"),
@@ -787,12 +835,12 @@ async def compute_full_batch_metrics_and_get_sampling_client(
 async def do_train_step_streaming_and_get_sampling_client(
     cfg: Config,
     i_batch: int,
-    trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | None],
+    trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None],
     training_client: tinker.TrainingClient,
     service_client: tinker.ServiceClient,
     tokenizer: Tokenizer,
     trajectory_group_filter: Callable[[WrappedTrajectoryGroup | None], bool] = lambda _: True,
-) -> tuple[tinker.SamplingClient, dict[str, Any]]:
+) -> tuple[tinker.SamplingClient, dict[str, Any]] | None:
     """
     As soon as we have enough trajectories for a minibatch, we will train on them.
     This allows us to overlap sampling and training.
@@ -825,8 +873,16 @@ async def do_train_step_streaming_and_get_sampling_client(
     i_minibatch = 0
     while i_minibatch < cfg.stream_minibatch_config.num_minibatches:
         wrapped_trajectory_group = await trajectory_groups_queue.get()
-        if not trajectory_group_filter(wrapped_trajectory_group):
-            continue
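+        # A Shutdown sentinel means the producers are shutting down; return None so the caller can stop too.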
+        match wrapped_trajectory_group:
+            case WrappedTrajectoryGroup():
+                pass
+            case Shutdown():
+                logger.info(
+                    "[do_train_step_streaming_and_get_sampling_client] Received shutdown signal"
+                )
+                return None
+            case None:
+                continue
         wrapped_trajectory_groups.append(wrapped_trajectory_group)
 
         if len(wrapped_trajectory_groups) < groups_per_minibatch: