Skip to content

Commit 85b56ad

Browse files
committed
Type everest/storage
1 parent f92f048 commit 85b56ad

File tree

2 files changed

+35
-37
lines changed

2 files changed

+35
-37
lines changed

.mypy.ini

-19
Original file line numberDiff line numberDiff line change
@@ -103,25 +103,6 @@ disable_error_code = dict-item,
103103
name-defined
104104

105105

106-
[mypy-everest.everest_storage.*]
107-
disable_error_code = dict-item,
108-
no-untyped-def,
109-
call-overload,
110-
union-attr,
111-
no-untyped-call,
112-
var-annotated,
113-
index,
114-
call-arg,
115-
unused-ignore,
116-
arg-type,
117-
type-arg,
118-
type-var,
119-
assignment,
120-
typeddict-item,
121-
attr-defined,
122-
comparison-overlap,
123-
return-value,
124-
name-defined
125106

126107

127108

src/everest/everest_storage.py

+35-18
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import os
66
import traceback
7+
from collections.abc import Callable
78
from dataclasses import dataclass, field
89
from functools import partial
910
from pathlib import Path
@@ -24,7 +25,7 @@
2425
@dataclass
2526
class OptimalResult:
2627
batch: int
27-
controls: list[Any]
28+
controls: dict[str, Any]
2829
total_objective: float
2930

3031

@@ -155,7 +156,7 @@ def read_from_experiment(self, experiment: _OptimizerOnlyExperiment) -> None:
155156
self.batches.append(
156157
BatchStorageData(
157158
batch_id=info["batch_id"],
158-
**{
159+
**{ # type: ignore
159160
df_name: try_read_df(
160161
Path(ens.optimizer_mount_point) / f"{df_name}.parquet"
161162
)
@@ -198,8 +199,8 @@ class _OptimizerOnlyExperiment:
198199
"""
199200

200201
def __init__(self, output_dir: Path) -> None:
201-
self._output_dir = output_dir
202-
self._ensembles = {}
202+
self._output_dir: Path = output_dir
203+
self._ensembles: dict[str, _OptimizerOnlyEnsemble] = {}
203204

204205
@property
205206
def optimizer_mount_point(self) -> Path:
@@ -329,7 +330,7 @@ def _ropt_to_df(
329330
field: str,
330331
*,
331332
values: list[str],
332-
select: list,
333+
select: list[str],
333334
) -> pl.DataFrame:
334335
df = pl.from_pandas(
335336
results.to_dataframe(field, select=values).reset_index(),
@@ -340,17 +341,23 @@ def _ropt_to_df(
340341
# retrieved from the everest configuration and were stored in the init
341342
# method. Here we replace the indices with those names:
342343
ropt_to_everest_names = {
343-
"variable": self.data.controls["control_name"],
344-
"objective": self.data.objective_functions["objective_name"],
344+
"variable": self.data.controls["control_name"]
345+
if self.data.controls is not None
346+
else None,
347+
"objective": self.data.objective_functions["objective_name"]
348+
if self.data.objective_functions is not None
349+
else None,
345350
"nonlinear_constraint": (
346351
self.data.nonlinear_constraints["constraint_name"]
347352
if self.data.nonlinear_constraints is not None
348353
else None
349354
),
350-
"realization": self.data.realization_weights["realization"],
355+
"realization": self.data.realization_weights["realization"]
356+
if self.data.realization_weights is not None
357+
else None,
351358
}
352359
df = df.with_columns(
353-
pl.col(ropt_name).replace_strict(dict(enumerate(everest_names)))
360+
pl.col(ropt_name).replace_strict(dict(enumerate(everest_names))) # type: ignore
354361
for ropt_name, everest_names in ropt_to_everest_names.items()
355362
if ropt_name in select
356363
)
@@ -367,7 +374,7 @@ def write_to_output_dir(self) -> None:
367374
self.data.write_to_experiment(exp)
368375

369376
@staticmethod
370-
def check_for_deprecated_seba_storage(config_file: str):
377+
def check_for_deprecated_seba_storage(config_file: str) -> None:
371378
config = EverestConfig.load_file(config_file)
372379
output_dir = Path(config.optimization_output_dir)
373380
if os.path.exists(output_dir / "seba.db") or os.path.exists(
@@ -509,7 +516,7 @@ def _store_function_results(self, results: FunctionResults) -> _EvaluationResult
509516
separator=":",
510517
)
511518

512-
realization_objectives = realization_objectives.pivot(
519+
realization_objectives = realization_objectives.pivot( # type: ignore
513520
values="objective_value",
514521
index=[
515522
"batch_id",
@@ -673,14 +680,16 @@ def _on_batch_evaluation_finished(self, event: Event) -> None:
673680
and item.functions is not None
674681
and item.functions.weighted_objective > best_value
675682
):
676-
best_value = item.functions.weighted_objective
683+
best_value = float(item.functions.weighted_objective)
677684
best_results = item
678685

679686
if best_results is not None:
680687
results = [best_results, *results]
681688

682-
batch_dicts = {}
689+
batch_dicts: dict[int, Any] = {}
683690
for item in results:
691+
assert item.batch_id is not None
692+
684693
if item.batch_id not in batch_dicts:
685694
batch_dicts[item.batch_id] = {}
686695

@@ -710,7 +719,7 @@ def _on_batch_evaluation_finished(self, event: Event) -> None:
710719
)
711720
)
712721

713-
def _on_optimization_finished(self, _) -> None:
722+
def _on_optimization_finished(self, _: Any) -> None:
714723
logger.debug("Storing final results Everest storage")
715724

716725
merit_values = self._get_merit_values()
@@ -729,6 +738,7 @@ def _on_optimization_finished(self, _) -> None:
729738
if merit_value is None:
730739
continue
731740

741+
assert b.batch_objectives is not None
732742
b.batch_objectives = b.batch_objectives.with_columns(
733743
pl.lit(merit_value).alias("merit_value")
734744
)
@@ -754,8 +764,9 @@ def get_optimal_result(self) -> OptimalResult | None:
754764
)
755765

756766
def find_best_batch(
757-
filter_by, sort_by
758-
) -> tuple[BatchStorageData | None, dict | None]:
767+
filter_by: Callable[[BatchStorageData], bool],
768+
sort_by: Callable[[BatchStorageData], Any],
769+
) -> tuple[BatchStorageData | None, dict[str, Any] | None]:
759770
matching_batches = [b for b in self.data.batches if filter_by(b)]
760771

761772
if not matching_batches:
@@ -780,14 +791,17 @@ def find_best_batch(
780791
b.batch_objectives is not None
781792
and "merit_value" in b.batch_objectives.columns
782793
),
783-
sort_by=lambda b: b.batch_objectives.select(
794+
sort_by=lambda b: b.batch_objectives.select( # type: ignore
784795
pl.col("merit_value").min()
785796
).item(),
786797
)
787798

788799
if batch is None:
789800
return None
790801

802+
assert controls_dict is not None
803+
assert batch.batch_objectives is not None
804+
791805
return OptimalResult(
792806
batch=batch.batch_id,
793807
controls=controls_dict,
@@ -800,14 +814,17 @@ def find_best_batch(
800814
batch, controls_dict = find_best_batch(
801815
filter_by=lambda b: b.batch_objectives is not None
802816
and not b.batch_objectives.is_empty(),
803-
sort_by=lambda b: -b.batch_objectives.select(
817+
sort_by=lambda b: -b.batch_objectives.select( # type: ignore
804818
pl.col("total_objective_value").sample(n=1)
805819
).item(),
806820
)
807821

808822
if batch is None:
809823
return None
810824

825+
assert controls_dict is not None
826+
assert batch.batch_objectives is not None
827+
811828
return OptimalResult(
812829
batch=batch.batch_id,
813830
controls=controls_dict,

0 commit comments

Comments
 (0)