@@ -94,6 +94,11 @@ def delete_runpath(run_path: str) -> None:
94
94
shutil .rmtree (run_path )
95
95
96
96
97
+ class _UserCancelled (Exception ):
98
+ def __str__ (self ) -> str :
99
+ return "Experiment cancelled by user"
100
+
101
+
97
102
class _LogAggregration (logging .Handler ):
98
103
def __init__ (self , messages : MutableSequence [str ]) -> None :
99
104
self .messages = messages
@@ -571,7 +576,7 @@ async def run_monitor(
571
576
# Allow track() to emit an EndEvent.
572
577
return False
573
578
elif type (event ) is EETerminated :
574
- logger .debug ("got terminator event" )
579
+ logger .debug ("got terminated event" )
575
580
576
581
if not self ._end_queue .empty ():
577
582
logger .debug ("Run model canceled - during evaluation" )
@@ -594,11 +599,12 @@ async def run_ensemble_evaluator_async(
594
599
run_args : list [RunArg ],
595
600
ensemble : Ensemble ,
596
601
ee_config : EvaluatorServerConfig ,
597
- ) -> list [int ]:
602
+ ) -> list [int ] | _UserCancelled :
598
603
if not self ._end_queue .empty ():
599
604
logger .debug ("Run model canceled - pre evaluation" )
600
605
self ._end_queue .get ()
601
- return []
606
+ return _UserCancelled ()
607
+
602
608
ee_ensemble = self ._build_ensemble (run_args , ensemble .experiment_id )
603
609
evaluator = EnsembleEvaluator (
604
610
ee_ensemble ,
@@ -610,7 +616,7 @@ async def run_ensemble_evaluator_async(
610
616
await evaluator ._server_started
611
617
if not (await self .run_monitor (ee_config , ensemble .iteration )):
612
618
await evaluator_task
613
- return []
619
+ return _UserCancelled ()
614
620
615
621
logger .debug ("observed that model was finished, waiting tasks completion..." )
616
622
# The model has finished, we indicate this by sending a DONE
@@ -620,7 +626,8 @@ async def run_ensemble_evaluator_async(
620
626
logger .debug ("Run model canceled - post evaluation" )
621
627
self ._end_queue .get ()
622
628
await evaluator_task
623
- return []
629
+ return _UserCancelled ()
630
+
624
631
await evaluator_task
625
632
ensemble .refresh_ensemble_state ()
626
633
@@ -633,11 +640,10 @@ def run_ensemble_evaluator(
633
640
run_args : list [RunArg ],
634
641
ensemble : Ensemble ,
635
642
ee_config : EvaluatorServerConfig ,
636
- ) -> list [int ]:
637
- successful_realizations = asyncio .run (
643
+ ) -> list [int ] | _UserCancelled :
644
+ return asyncio .run (
638
645
self .run_ensemble_evaluator_async (run_args , ensemble , ee_config )
639
646
)
640
- return successful_realizations
641
647
642
648
def _build_ensemble (
643
649
self ,
@@ -757,11 +763,16 @@ def _evaluate_and_postprocess(
757
763
"run_paths" : self .run_paths ,
758
764
},
759
765
)
760
- successful_realizations = self .run_ensemble_evaluator (
766
+ result = self .run_ensemble_evaluator (
761
767
run_args ,
762
768
ensemble ,
763
769
evaluator_server_config ,
764
770
)
771
+ if type (result ) is _UserCancelled :
772
+ self .active_realizations = [False for _ in self .active_realizations ]
773
+ raise result
774
+ successful_realizations = cast (list [int ], result )
775
+
765
776
starting_realizations = [real .iens for real in run_args if real .active ]
766
777
failed_realizations = list (
767
778
set (starting_realizations ) - set (successful_realizations )
0 commit comments