4
4
import logging
5
5
import os
6
6
import traceback
7
+ from collections .abc import Callable
7
8
from dataclasses import dataclass , field
8
9
from functools import partial
9
10
from pathlib import Path
24
25
@dataclass
25
26
class OptimalResult :
26
27
batch : int
27
- controls : list [ Any ]
28
+ controls : dict [ str , Any ]
28
29
total_objective : float
29
30
30
31
@@ -155,7 +156,7 @@ def read_from_experiment(self, experiment: _OptimizerOnlyExperiment) -> None:
155
156
self .batches .append (
156
157
BatchStorageData (
157
158
batch_id = info ["batch_id" ],
158
- ** {
159
+ ** { # type: ignore
159
160
df_name : try_read_df (
160
161
Path (ens .optimizer_mount_point ) / f"{ df_name } .parquet"
161
162
)
@@ -198,8 +199,8 @@ class _OptimizerOnlyExperiment:
198
199
"""
199
200
200
201
def __init__ (self , output_dir : Path ) -> None :
201
- self ._output_dir = output_dir
202
- self ._ensembles = {}
202
+ self ._output_dir : Path = output_dir
203
+ self ._ensembles : dict [ str , _OptimizerOnlyEnsemble ] = {}
203
204
204
205
@property
205
206
def optimizer_mount_point (self ) -> Path :
@@ -329,7 +330,7 @@ def _ropt_to_df(
329
330
field : str ,
330
331
* ,
331
332
values : list [str ],
332
- select : list ,
333
+ select : list [ str ] ,
333
334
) -> pl .DataFrame :
334
335
df = pl .from_pandas (
335
336
results .to_dataframe (field , select = values ).reset_index (),
@@ -340,17 +341,23 @@ def _ropt_to_df(
340
341
# retrieved from the everest configuration and were stored in the init
341
342
# method. Here we replace the indices with those names:
342
343
ropt_to_everest_names = {
343
- "variable" : self .data .controls ["control_name" ],
344
- "objective" : self .data .objective_functions ["objective_name" ],
344
+ "variable" : self .data .controls ["control_name" ]
345
+ if self .data .controls is not None
346
+ else None ,
347
+ "objective" : self .data .objective_functions ["objective_name" ]
348
+ if self .data .objective_functions is not None
349
+ else None ,
345
350
"nonlinear_constraint" : (
346
351
self .data .nonlinear_constraints ["constraint_name" ]
347
352
if self .data .nonlinear_constraints is not None
348
353
else None
349
354
),
350
- "realization" : self .data .realization_weights ["realization" ],
355
+ "realization" : self .data .realization_weights ["realization" ]
356
+ if self .data .realization_weights is not None
357
+ else None ,
351
358
}
352
359
df = df .with_columns (
353
- pl .col (ropt_name ).replace_strict (dict (enumerate (everest_names )))
360
+ pl .col (ropt_name ).replace_strict (dict (enumerate (everest_names ))) # type: ignore
354
361
for ropt_name , everest_names in ropt_to_everest_names .items ()
355
362
if ropt_name in select
356
363
)
@@ -367,7 +374,7 @@ def write_to_output_dir(self) -> None:
367
374
self .data .write_to_experiment (exp )
368
375
369
376
@staticmethod
370
- def check_for_deprecated_seba_storage (config_file : str ):
377
+ def check_for_deprecated_seba_storage (config_file : str ) -> None :
371
378
config = EverestConfig .load_file (config_file )
372
379
output_dir = Path (config .optimization_output_dir )
373
380
if os .path .exists (output_dir / "seba.db" ) or os .path .exists (
@@ -509,7 +516,7 @@ def _store_function_results(self, results: FunctionResults) -> _EvaluationResult
509
516
separator = ":" ,
510
517
)
511
518
512
- realization_objectives = realization_objectives .pivot (
519
+ realization_objectives = realization_objectives .pivot ( # type: ignore
513
520
values = "objective_value" ,
514
521
index = [
515
522
"batch_id" ,
@@ -673,14 +680,16 @@ def _on_batch_evaluation_finished(self, event: Event) -> None:
673
680
and item .functions is not None
674
681
and item .functions .weighted_objective > best_value
675
682
):
676
- best_value = item .functions .weighted_objective
683
+ best_value = float ( item .functions .weighted_objective )
677
684
best_results = item
678
685
679
686
if best_results is not None :
680
687
results = [best_results , * results ]
681
688
682
- batch_dicts = {}
689
+ batch_dicts : dict [ int , Any ] = {}
683
690
for item in results :
691
+ assert item .batch_id is not None
692
+
684
693
if item .batch_id not in batch_dicts :
685
694
batch_dicts [item .batch_id ] = {}
686
695
@@ -710,7 +719,7 @@ def _on_batch_evaluation_finished(self, event: Event) -> None:
710
719
)
711
720
)
712
721
713
- def _on_optimization_finished (self , _ ) -> None :
722
+ def _on_optimization_finished (self , _ : Any ) -> None :
714
723
logger .debug ("Storing final results Everest storage" )
715
724
716
725
merit_values = self ._get_merit_values ()
@@ -729,6 +738,7 @@ def _on_optimization_finished(self, _) -> None:
729
738
if merit_value is None :
730
739
continue
731
740
741
+ assert b .batch_objectives is not None
732
742
b .batch_objectives = b .batch_objectives .with_columns (
733
743
pl .lit (merit_value ).alias ("merit_value" )
734
744
)
@@ -754,8 +764,9 @@ def get_optimal_result(self) -> OptimalResult | None:
754
764
)
755
765
756
766
def find_best_batch (
757
- filter_by , sort_by
758
- ) -> tuple [BatchStorageData | None , dict | None ]:
767
+ filter_by : Callable [[BatchStorageData ], bool ],
768
+ sort_by : Callable [[BatchStorageData ], Any ],
769
+ ) -> tuple [BatchStorageData | None , dict [str , Any ] | None ]:
759
770
matching_batches = [b for b in self .data .batches if filter_by (b )]
760
771
761
772
if not matching_batches :
@@ -780,14 +791,17 @@ def find_best_batch(
780
791
b .batch_objectives is not None
781
792
and "merit_value" in b .batch_objectives .columns
782
793
),
783
- sort_by = lambda b : b .batch_objectives .select (
794
+ sort_by = lambda b : b .batch_objectives .select ( # type: ignore
784
795
pl .col ("merit_value" ).min ()
785
796
).item (),
786
797
)
787
798
788
799
if batch is None :
789
800
return None
790
801
802
+ assert controls_dict is not None
803
+ assert batch .batch_objectives is not None
804
+
791
805
return OptimalResult (
792
806
batch = batch .batch_id ,
793
807
controls = controls_dict ,
@@ -800,14 +814,17 @@ def find_best_batch(
800
814
batch , controls_dict = find_best_batch (
801
815
filter_by = lambda b : b .batch_objectives is not None
802
816
and not b .batch_objectives .is_empty (),
803
- sort_by = lambda b : - b .batch_objectives .select (
817
+ sort_by = lambda b : - b .batch_objectives .select ( # type: ignore
804
818
pl .col ("total_objective_value" ).sample (n = 1 )
805
819
).item (),
806
820
)
807
821
808
822
if batch is None :
809
823
return None
810
824
825
+ assert controls_dict is not None
826
+ assert batch .batch_objectives is not None
827
+
811
828
return OptimalResult (
812
829
batch = batch .batch_id ,
813
830
controls = controls_dict ,
0 commit comments