@@ -94,6 +94,10 @@ def delete_runpath(run_path: str) -> None:
94
94
shutil .rmtree (run_path )
95
95
96
96
97
+ class _UserCancelled (Exception ):
98
+ pass
99
+
100
+
97
101
class _LogAggregration (logging .Handler ):
98
102
def __init__ (self , messages : MutableSequence [str ]) -> None :
99
103
self .messages = messages
@@ -571,7 +575,7 @@ async def run_monitor(
571
575
# Allow track() to emit an EndEvent.
572
576
return False
573
577
elif type (event ) is EETerminated :
574
- logger .debug ("got terminator event" )
578
+ logger .debug ("got terminated event" )
575
579
576
580
if not self ._end_queue .empty ():
577
581
logger .debug ("Run model canceled - during evaluation" )
@@ -594,11 +598,12 @@ async def run_ensemble_evaluator_async(
594
598
run_args : list [RunArg ],
595
599
ensemble : Ensemble ,
596
600
ee_config : EvaluatorServerConfig ,
597
- ) -> list [int ]:
601
+ ) -> list [int ] | _UserCancelled :
598
602
if not self ._end_queue .empty ():
599
603
logger .debug ("Run model canceled - pre evaluation" )
600
604
self ._end_queue .get ()
601
- return []
605
+ return _UserCancelled ()
606
+
602
607
ee_ensemble = self ._build_ensemble (run_args , ensemble .experiment_id )
603
608
evaluator = EnsembleEvaluator (
604
609
ee_ensemble ,
@@ -610,7 +615,7 @@ async def run_ensemble_evaluator_async(
610
615
await evaluator ._server_started
611
616
if not (await self .run_monitor (ee_config , ensemble .iteration )):
612
617
await evaluator_task
613
- return []
618
+ return _UserCancelled ()
614
619
615
620
logger .debug ("observed that model was finished, waiting tasks completion..." )
616
621
# The model has finished, we indicate this by sending a DONE
@@ -620,7 +625,8 @@ async def run_ensemble_evaluator_async(
620
625
logger .debug ("Run model canceled - post evaluation" )
621
626
self ._end_queue .get ()
622
627
await evaluator_task
623
- return []
628
+ return _UserCancelled ()
629
+
624
630
await evaluator_task
625
631
ensemble .refresh_ensemble_state ()
626
632
@@ -633,11 +639,10 @@ def run_ensemble_evaluator(
633
639
run_args : list [RunArg ],
634
640
ensemble : Ensemble ,
635
641
ee_config : EvaluatorServerConfig ,
636
- ) -> list [int ]:
637
- successful_realizations = asyncio .run (
642
+ ) -> list [int ] | _UserCancelled :
643
+ return asyncio .run (
638
644
self .run_ensemble_evaluator_async (run_args , ensemble , ee_config )
639
645
)
640
- return successful_realizations
641
646
642
647
def _build_ensemble (
643
648
self ,
@@ -757,11 +762,16 @@ def _evaluate_and_postprocess(
757
762
"run_paths" : self .run_paths ,
758
763
},
759
764
)
760
- successful_realizations = self .run_ensemble_evaluator (
765
+ result = self .run_ensemble_evaluator (
761
766
run_args ,
762
767
ensemble ,
763
768
evaluator_server_config ,
764
769
)
770
+ if type (result ) is _UserCancelled :
771
+ self .active_realizations = [False for _ in self .active_realizations ]
772
+ raise _UserCancelled ("Experiment cancelled by user" )
773
+ successful_realizations = cast (list [int ], result )
774
+
765
775
starting_realizations = [real .iens for real in run_args if real .active ]
766
776
failed_realizations = list (
767
777
set (starting_realizations ) - set (successful_realizations )
0 commit comments