Skip to content

Commit ebef6f6

Browse files
committed
Add function to save and copy non updated scalar parameters
1 parent 0718d01 commit ebef6f6

File tree

3 files changed

+61
-55
lines changed

3 files changed

+61
-55
lines changed

src/ert/analysis/_es_update.py

+26-28
Original file line numberDiff line numberDiff line change
@@ -109,28 +109,27 @@ def _all_parameters(
109109

110110

111111
def _save_param_ensemble_array_to_disk(
112-
ensemble: Ensemble,
112+
source_ensemble: Ensemble,
113+
target_ensemble: Ensemble,
113114
param_ensemble_array: npt.NDArray[np.float64],
114115
param_group: str,
115116
iens_active_index: npt.NDArray[np.int_],
116117
) -> None:
117-
config_node = ensemble.experiment.parameter_configuration[param_group]
118-
for i, realization in enumerate(iens_active_index):
119-
config_node.save_parameters(
120-
ensemble, param_group, realization, param_ensemble_array[:, i]
121-
)
122-
123-
124-
def _load_param_ensemble_array_scalar(
125-
ensemble: Ensemble,
126-
scalar_params: list[str],
127-
iens_active_index: npt.NDArray[np.int_],
128-
) -> npt.NDArray[np.float64]:
129-
config_node = ensemble.experiment.parameter_configuration["SCALAR_PARAMETERS"]
130-
assert isinstance(config_node, ScalarParameters)
131-
return config_node.load_parameters_scalar(
132-
ensemble, scalar_params, iens_active_index
133-
)
118+
config_node = target_ensemble.experiment.parameter_configuration[param_group]
119+
if isinstance(config_node, ScalarParameters):
120+
for i, realization in enumerate(iens_active_index):
121+
config_node.save_updated_parameters_and_copy_remaining(
122+
source_ensemble,
123+
target_ensemble,
124+
param_group,
125+
realization,
126+
param_ensemble_array[:, i],
127+
)
128+
else:
129+
for i, realization in enumerate(iens_active_index):
130+
config_node.save_parameters(
131+
target_ensemble, param_group, realization, param_ensemble_array[:, i]
132+
)
134133

135134

136135
def _load_param_ensemble_array(
@@ -140,7 +139,7 @@ def _load_param_ensemble_array(
140139
) -> npt.NDArray[np.float64]:
141140
config_node = ensemble.experiment.parameter_configuration[param_group]
142141
if isinstance(config_node, ScalarParameters):
143-
raise ValueError("Config node is scalar!")
142+
return config_node.load_parameters_to_update(ensemble, iens_active_index)
144143
dataset = config_node.load_parameters(ensemble, param_group, iens_active_index)
145144
assert isinstance(dataset, np.ndarray), "dataset is not an numpy array"
146145
return dataset
@@ -556,14 +555,9 @@ def correlation_callback(
556555
cross_correlations_accumulator.append(cross_correlations_of_batch)
557556

558557
for param_group in parameters:
559-
if isinstance(param_group, list):
560-
param_ensemble_array = _load_param_ensemble_array_scalar(
561-
source_ensemble, iens_active_index
562-
)
563-
else:
564-
param_ensemble_array = _load_param_ensemble_array(
565-
source_ensemble, param_group, iens_active_index
566-
)
558+
param_ensemble_array = _load_param_ensemble_array(
559+
source_ensemble, param_group, iens_active_index
560+
)
567561
if module.localization:
568562
config_node = source_ensemble.experiment.parameter_configuration[
569563
param_group
@@ -629,7 +623,11 @@ def correlation_callback(
629623
start = time.time()
630624

631625
_save_param_ensemble_array_to_disk(
632-
target_ensemble, param_ensemble_array, param_group, iens_active_index
626+
source_ensemble,
627+
target_ensemble,
628+
param_ensemble_array,
629+
param_group,
630+
iens_active_index,
633631
)
634632
logger.info(
635633
f"Storing data for {param_group} completed in {(time.time() - start) / 60} minutes"

src/ert/config/scalar_parameter.py

+32-4
Original file line numberDiff line numberDiff line change
@@ -376,14 +376,20 @@ def sample_or_load(
376376

377377
return params_ds
378378

379-
def load_parameters_scalar(
379+
def load_parameters_to_update(
380380
self,
381381
ensemble: Ensemble,
382-
scalar_params: list[str],
383382
iens_active_index: npt.NDArray[np.int_],
384383
) -> npt.NDArray[np.float64]:
384+
params_to_update = [
385+
f"{param.group_name}:{param.param_name}"
386+
for param in self.scalar_params
387+
if param.update
388+
]
389+
if not params_to_update:
390+
raise ValueError("No parameters to update")
385391
ds = self.load_parameters(ensemble, "SCALAR_PARAMETERS", iens_active_index)
386-
return ds.sel(type="transformed")[scalar_params].to_array().values.T
392+
return ds.sel(type="transformed")[params_to_update].to_array().values
387393

388394
@staticmethod
389395
def load_parameters(
@@ -404,9 +410,31 @@ def save_parameters(
404410
realization: int,
405411
data: npt.NDArray[np.float64],
406412
) -> None:
407-
# This would require to change the API potentially
413+
# this function is not used in the current implementation
408414
pass
409415

416+
def save_updated_parameters_and_copy_remaining(
417+
self,
418+
source_ensemble: Ensemble,
419+
target_ensemble: Ensemble,
420+
group: str,
421+
realization: int,
422+
data: npt.NDArray[np.float64],
423+
) -> None:
424+
params_to_update = [
425+
f"{param.group_name}:{param.param_name}"
426+
for param in self.scalar_params
427+
if param.update
428+
]
429+
ds = self.load_parameters(source_ensemble, group, np.array(realization))
430+
ds.update(
431+
{
432+
var: (["type"], np.where(ds["type"] == "raw", data, ds[var]))
433+
for i, var in enumerate(params_to_update)
434+
}
435+
)
436+
target_ensemble.save_parameters(group, realization, ds)
437+
410438
def __len__(self) -> int:
411439
return len(self.scalar_params)
412440

src/ert/storage/local_experiment.py

+3-23
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,7 @@
1313
import xtgeo
1414
from pydantic import BaseModel
1515

16-
from ert.config import (
17-
ExtParamConfig,
18-
Field,
19-
GenKwConfig,
20-
ScalarParameters,
21-
SurfaceConfig,
22-
)
16+
from ert.config import ExtParamConfig, Field, GenKwConfig, SurfaceConfig
2317
from ert.config.parsing.context_values import ContextBoolEncoder
2418
from ert.config.response_config import ResponseConfig
2519
from ert.storage.mode import BaseMode, Mode, require_write
@@ -304,23 +298,9 @@ def response_configuration(self) -> dict[str, ResponseConfig]:
304298

305299
return responses
306300

307-
# @cached_property
308-
# def update_parameters(self) -> list[str]:
309-
# return [p.name for p in self.parameter_configuration.values() if p.update]
310301
@cached_property
311-
def update_parameters(self) -> list[str | list[str]]:
312-
params_to_update = []
313-
for _, config_node in self.parameter_configuration.items():
314-
scalar_params = []
315-
if isinstance(config_node, ScalarParameters):
316-
for param in config_node.scalar_params:
317-
if param.update:
318-
scalar_params.append(f"{param.group_name}:{param.param_name}")
319-
if scalar_params:
320-
params_to_update.append(scalar_params)
321-
elif config_node.update:
322-
params_to_update.append(config_node.name)
323-
return params_to_update
302+
def update_parameters(self) -> list[str]:
303+
return [p.name for p in self.parameter_configuration.values() if p.update]
324304

325305
@cached_property
326306
def observations(self) -> dict[str, pl.DataFrame]:

0 commit comments

Comments
 (0)