Skip to content

Commit 0718d01

Browse files
committed
Add function to load_subset of parameters from ScalarParameters
1 parent 66aa11b commit 0718d01

File tree

4 files changed

+68
-20
lines changed

4 files changed

+68
-20
lines changed

src/ert/analysis/_es_update.py

+34-15
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,7 @@
55
import time
66
from collections.abc import Callable, Iterable, Sequence
77
from fnmatch import fnmatch
8-
from typing import (
9-
TYPE_CHECKING,
10-
Generic,
11-
Self,
12-
TypeVar,
13-
)
8+
from typing import TYPE_CHECKING, Generic, Self, TypeVar
149

1510
import iterative_ensemble_smoother as ies
1611
import numpy as np
@@ -19,7 +14,13 @@
1914
import scipy
2015
from iterative_ensemble_smoother.experimental import AdaptiveESMDA
2116

22-
from ert.config import ESSettings, GenKwConfig, ObservationGroups, UpdateSettings
17+
from ert.config import (
18+
ESSettings,
19+
GenKwConfig,
20+
ObservationGroups,
21+
ScalarParameters,
22+
UpdateSettings,
23+
)
2324

2425
from . import misfit_preprocessor
2526
from .event import (
@@ -31,10 +32,7 @@
3132
AnalysisTimeEvent,
3233
DataSection,
3334
)
34-
from .snapshots import (
35-
ObservationAndResponseSnapshot,
36-
SmootherSnapshot,
37-
)
35+
from .snapshots import ObservationAndResponseSnapshot, SmootherSnapshot
3836

3937
if TYPE_CHECKING:
4038
import numpy.typing as npt
@@ -123,13 +121,29 @@ def _save_param_ensemble_array_to_disk(
123121
)
124122

125123

124+
def _load_param_ensemble_array_scalar(
125+
ensemble: Ensemble,
126+
scalar_params: list[str],
127+
iens_active_index: npt.NDArray[np.int_],
128+
) -> npt.NDArray[np.float64]:
129+
config_node = ensemble.experiment.parameter_configuration["SCALAR_PARAMETERS"]
130+
assert isinstance(config_node, ScalarParameters)
131+
return config_node.load_parameters_scalar(
132+
ensemble, scalar_params, iens_active_index
133+
)
134+
135+
126136
def _load_param_ensemble_array(
127137
ensemble: Ensemble,
128138
param_group: str,
129139
iens_active_index: npt.NDArray[np.int_],
130140
) -> npt.NDArray[np.float64]:
131141
config_node = ensemble.experiment.parameter_configuration[param_group]
132-
return config_node.load_parameters(ensemble, param_group, iens_active_index)
142+
if isinstance(config_node, ScalarParameters):
143+
raise ValueError("Config node is scalar!")
144+
dataset = config_node.load_parameters(ensemble, param_group, iens_active_index)
145+
assert isinstance(dataset, np.ndarray), "dataset is not an numpy array"
146+
return dataset
133147

134148

135149
def _expand_wildcards(
@@ -542,9 +556,14 @@ def correlation_callback(
542556
cross_correlations_accumulator.append(cross_correlations_of_batch)
543557

544558
for param_group in parameters:
545-
param_ensemble_array = _load_param_ensemble_array(
546-
source_ensemble, param_group, iens_active_index
547-
)
559+
if isinstance(param_group, list):
560+
param_ensemble_array = _load_param_ensemble_array_scalar(
561+
source_ensemble, iens_active_index
562+
)
563+
else:
564+
param_ensemble_array = _load_param_ensemble_array(
565+
source_ensemble, param_group, iens_active_index
566+
)
548567
if module.localization:
549568
config_node = source_ensemble.experiment.parameter_configuration[
550569
param_group

src/ert/config/parameter_config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def save_parameters(
103103
@abstractmethod
104104
def load_parameters(
105105
self, ensemble: Ensemble, group: str, realizations: npt.NDArray[np.int_]
106-
) -> npt.NDArray[np.float64]:
106+
) -> npt.NDArray[np.float64] | xr.Dataset:
107107
"""
108108
Load the parameter from internal storage for the given ensemble.
109109
Must return array of shape (number of parameters, number of realizations).

src/ert/config/scalar_parameter.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ class ScalarParameter:
278278
| TransEmptySettings
279279
)
280280
input_source: DataSource
281+
update: bool = True
281282
# dataset_file: PolarsData | None
282283

283284

@@ -375,13 +376,21 @@ def sample_or_load(
375376

376377
return params_ds
377378

379+
def load_parameters_scalar(
380+
self,
381+
ensemble: Ensemble,
382+
scalar_params: list[str],
383+
iens_active_index: npt.NDArray[np.int_],
384+
) -> npt.NDArray[np.float64]:
385+
ds = self.load_parameters(ensemble, "SCALAR_PARAMETERS", iens_active_index)
386+
return ds.sel(type="transformed")[scalar_params].to_array().values.T
387+
378388
@staticmethod
379389
def load_parameters(
380390
ensemble: Ensemble, group: str, realizations: npt.NDArray[np.int_]
381391
) -> xr.Dataset:
382392
return xr.concat(
383393
[
384-
# ensemble.load_parameters_scalar(realization)
385394
ensemble.load_parameters(group, realizations)
386395
for realization in realizations
387396
],

src/ert/storage/local_experiment.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
import xtgeo
1414
from pydantic import BaseModel
1515

16-
from ert.config import ExtParamConfig, Field, GenKwConfig, SurfaceConfig
16+
from ert.config import (
17+
ExtParamConfig,
18+
Field,
19+
GenKwConfig,
20+
ScalarParameters,
21+
SurfaceConfig,
22+
)
1723
from ert.config.parsing.context_values import ContextBoolEncoder
1824
from ert.config.response_config import ResponseConfig
1925
from ert.storage.mode import BaseMode, Mode, require_write
@@ -298,9 +304,23 @@ def response_configuration(self) -> dict[str, ResponseConfig]:
298304

299305
return responses
300306

307+
# @cached_property
308+
# def update_parameters(self) -> list[str]:
309+
# return [p.name for p in self.parameter_configuration.values() if p.update]
301310
@cached_property
302-
def update_parameters(self) -> list[str]:
303-
return [p.name for p in self.parameter_configuration.values() if p.update]
311+
def update_parameters(self) -> list[str | list[str]]:
312+
params_to_update = []
313+
for _, config_node in self.parameter_configuration.items():
314+
scalar_params = []
315+
if isinstance(config_node, ScalarParameters):
316+
for param in config_node.scalar_params:
317+
if param.update:
318+
scalar_params.append(f"{param.group_name}:{param.param_name}")
319+
if scalar_params:
320+
params_to_update.append(scalar_params)
321+
elif config_node.update:
322+
params_to_update.append(config_node.name)
323+
return params_to_update
304324

305325
@cached_property
306326
def observations(self) -> dict[str, pl.DataFrame]:

0 commit comments

Comments
 (0)