Skip to content

Commit dfd6e4c

Browse files
committed
Add ScalarParameters config
- transformations become dataclasses -> TODO: pydantic - ScalarParameter represents a single parameter instance with a datasource (DESIGN_MATRIX, SAMPLED) - all ScalarParameter instances are bound by ScalarParameters(ParameterConfig) and are stored for all realizations in a single pl.DataFrame (parquet) file. - API to load / save scalar parameters is in local_ensemble load|save_param_scalar - DESIGN_MATRIX now creates only ScalarParameter instances and modifies the ScalarParameters config. - Reimplement load_all_gen_kw_data - ScalarParameters: provide export to xr.Dataset - Update ensemble smoother - add pydantic validation to distributions - Fix ertsummary for scalars - ConfigValidationError on init_files for GEN_KW - Simple workaround for is_initialized with scalars - Replace GenKW with Scalars in test_when_manifest_files_are_written_forward_model_ok_succeeds
1 parent a03308a commit dfd6e4c

24 files changed

+1645
-682
lines changed

src/ert/analysis/_es_update.py

+34-19
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,7 @@
55
import time
66
from collections.abc import Callable, Iterable, Sequence
77
from fnmatch import fnmatch
8-
from typing import (
9-
TYPE_CHECKING,
10-
Generic,
11-
Self,
12-
TypeVar,
13-
)
8+
from typing import TYPE_CHECKING, Generic, Self, TypeVar
149

1510
import iterative_ensemble_smoother as ies
1611
import numpy as np
@@ -19,7 +14,13 @@
1914
import scipy
2015
from iterative_ensemble_smoother.experimental import AdaptiveESMDA
2116

22-
from ert.config import ESSettings, GenKwConfig, ObservationGroups, UpdateSettings
17+
from ert.config import (
18+
ESSettings,
19+
GenKwConfig,
20+
ObservationGroups,
21+
ScalarParameters,
22+
UpdateSettings,
23+
)
2324

2425
from . import misfit_preprocessor
2526
from .event import (
@@ -31,10 +32,7 @@
3132
AnalysisTimeEvent,
3233
DataSection,
3334
)
34-
from .snapshots import (
35-
ObservationAndResponseSnapshot,
36-
SmootherSnapshot,
37-
)
35+
from .snapshots import ObservationAndResponseSnapshot, SmootherSnapshot
3836

3937
if TYPE_CHECKING:
4038
import numpy.typing as npt
@@ -111,16 +109,26 @@ def _all_parameters(
111109

112110

113111
def _save_param_ensemble_array_to_disk(
114-
ensemble: Ensemble,
112+
source_ensemble: Ensemble,
113+
target_ensemble: Ensemble,
115114
param_ensemble_array: npt.NDArray[np.float64],
116115
param_group: str,
117116
iens_active_index: npt.NDArray[np.int_],
118117
) -> None:
119-
config_node = ensemble.experiment.parameter_configuration[param_group]
120-
for i, realization in enumerate(iens_active_index):
121-
config_node.save_parameters(
122-
ensemble, param_group, realization, param_ensemble_array[:, i]
118+
config_node = target_ensemble.experiment.parameter_configuration[param_group]
119+
if isinstance(config_node, ScalarParameters):
120+
config_node.save_updated_parameters_and_copy_remaining(
121+
source_ensemble,
122+
target_ensemble,
123+
param_group,
124+
iens_active_index,
125+
param_ensemble_array,
123126
)
127+
else:
128+
for i, realization in enumerate(iens_active_index):
129+
config_node.save_parameters(
130+
target_ensemble, param_group, realization, param_ensemble_array[:, i]
131+
)
124132

125133

126134
def _load_param_ensemble_array(
@@ -129,7 +137,11 @@ def _load_param_ensemble_array(
129137
iens_active_index: npt.NDArray[np.int_],
130138
) -> npt.NDArray[np.float64]:
131139
config_node = ensemble.experiment.parameter_configuration[param_group]
132-
return config_node.load_parameters(ensemble, param_group, iens_active_index)
140+
if isinstance(config_node, ScalarParameters):
141+
return config_node.load_parameters_to_update(ensemble, iens_active_index)
142+
dataset = config_node.load_parameters(ensemble, param_group, iens_active_index)
143+
assert isinstance(dataset, np.ndarray), "dataset is not an numpy array"
144+
return dataset
133145

134146

135147
def _expand_wildcards(
@@ -608,9 +620,12 @@ def correlation_callback(
608620
logger.info(log_msg)
609621
progress_callback(AnalysisStatusEvent(msg=log_msg))
610622
start = time.time()
611-
612623
_save_param_ensemble_array_to_disk(
613-
target_ensemble, param_ensemble_array, param_group, iens_active_index
624+
source_ensemble,
625+
target_ensemble,
626+
param_ensemble_array,
627+
param_group,
628+
iens_active_index,
614629
)
615630
logger.info(
616631
f"Storing data for {param_group} completed in {(time.time() - start) / 60} minutes"

src/ert/config/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131
from .parsing.observations_parser import ObservationType
3232
from .queue_config import QueueConfig
3333
from .response_config import InvalidResponseFile, ResponseConfig
34+
from .scalar_parameter import (
35+
SCALAR_PARAMETERS_NAME,
36+
DataSource,
37+
ScalarParameter,
38+
ScalarParameters,
39+
get_distribution,
40+
)
3441
from .summary_config import SummaryConfig
3542
from .summary_observation import SummaryObservation
3643
from .surface_config import SurfaceConfig
@@ -39,11 +46,13 @@
3946

4047
__all__ = [
4148
"DESIGN_MATRIX_GROUP",
49+
"SCALAR_PARAMETERS_NAME",
4250
"AnalysisConfig",
4351
"AnalysisModule",
4452
"ConfigValidationError",
4553
"ConfigValidationError",
4654
"ConfigWarning",
55+
"DataSource",
4756
"DesignMatrix",
4857
"ESSettings",
4958
"EnkfObs",
@@ -70,6 +79,8 @@
7079
"QueueConfig",
7180
"QueueSystem",
7281
"ResponseConfig",
82+
"ScalarParameter",
83+
"ScalarParameters",
7384
"SummaryConfig",
7485
"SummaryObservation",
7586
"SurfaceConfig",
@@ -80,5 +91,6 @@
8091
"WorkflowJob",
8192
"capture_validation",
8293
"field_transform",
94+
"get_distribution",
8395
"lint_file",
8496
]

src/ert/config/design_matrix.py

+49-61
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@
88
import pandas as pd
99
from pandas.api.types import is_integer_dtype
1010

11-
from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition
12-
1311
from ._option_dict import option_dict
1412
from .parsing import ConfigValidationError, ErrorInfo
13+
from .scalar_parameter import (
14+
DataSource,
15+
ScalarParameter,
16+
ScalarParameters,
17+
TransRawSettings,
18+
)
1519

1620
if TYPE_CHECKING:
17-
from ert.config import ParameterConfig
21+
pass
1822

1923
DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"
2024

@@ -32,7 +36,7 @@ def __post_init__(self) -> None:
3236
(
3337
self.active_realizations,
3438
self.design_matrix_df,
35-
self.parameter_configuration,
39+
self.scalars,
3640
) = self.read_and_validate_design_matrix()
3741
except (ValueError, AttributeError) as exc:
3842
raise ConfigValidationError.with_context(
@@ -102,66 +106,54 @@ def merge_with_other(self, dm_other: DesignMatrix) -> None:
102106
except ValueError as exc:
103107
errors.append(ErrorInfo(f"Error when merging design matrices {exc}!"))
104108

105-
for tfd in dm_other.parameter_configuration.transform_function_definitions:
106-
self.parameter_configuration.transform_function_definitions.append(tfd)
109+
for param in dm_other.scalars:
110+
self.scalars.append(param)
107111

108112
if errors:
109113
raise ConfigValidationError.from_collected(errors)
110114

111115
def merge_with_existing_parameters(
112-
self, existing_parameters: list[ParameterConfig]
113-
) -> tuple[list[ParameterConfig], GenKwConfig]:
116+
self, existing_scalars: ScalarParameters
117+
) -> ScalarParameters:
114118
"""
115119
This method merges the design matrix parameters with the existing parameters and
116-
returns the new list of existing parameters, wherein we drop GEN_KW group having a full overlap with the design matrix group.
117-
GEN_KW group that was dropped will acquire a new name from the design matrix group.
118-
Additionally, the ParameterConfig which is the design matrix group is returned separately.
119-
120+
returns the new list of existing parameters.
120121
Args:
121-
existing_parameters (List[ParameterConfig]): List of existing parameters
122+
existing_scalars (ScalarParameters): existing scalar parameters
122123
123-
Raises:
124-
ConfigValidationError: If there is a partial overlap between the design matrix group and any existing GEN_KW group
125124
126125
Returns:
127-
tuple[List[ParameterConfig], ParameterConfig]: List of existing parameters and the dedicated design matrix group
126+
ScalarParameters: new set of ScalarParameters
128127
"""
129128

130-
new_param_config: list[ParameterConfig] = []
131-
132-
design_parameter_group = self.parameter_configuration
133-
design_keys = [e.name for e in design_parameter_group.transform_functions]
129+
all_params: list[ScalarParameter] = []
134130

135-
design_group_added = False
136-
for parameter_group in existing_parameters:
137-
if not isinstance(parameter_group, GenKwConfig):
138-
new_param_config += [parameter_group]
131+
overlap_set = set()
132+
for existing_parameter in existing_scalars.scalars:
133+
if existing_parameter.input_source == DataSource.DESIGN_MATRIX:
139134
continue
140-
existing_keys = [e.name for e in parameter_group.transform_functions]
141-
if set(existing_keys) == set(design_keys):
142-
if design_group_added:
143-
raise ConfigValidationError(
144-
"Multiple overlapping groups with design matrix found in existing parameters!\n"
145-
f"{design_parameter_group.name} and {parameter_group.name}"
146-
)
147-
148-
design_parameter_group.name = parameter_group.name
149-
design_parameter_group.template_file = parameter_group.template_file
150-
design_parameter_group.output_file = parameter_group.output_file
151-
design_group_added = True
152-
elif set(design_keys) & set(existing_keys):
153-
raise ConfigValidationError(
154-
"Overlapping parameter names found in design matrix!\n"
155-
f"{DESIGN_MATRIX_GROUP}:{design_keys}\n{parameter_group.name}:{existing_keys}"
156-
"\nThey need to match exactly or not at all."
157-
)
158-
else:
159-
new_param_config += [parameter_group]
160-
return new_param_config, design_parameter_group
135+
overlap = False
136+
for parameter_design in self.scalars:
137+
if existing_parameter.param_name == parameter_design.param_name:
138+
parameter_design.group_name = existing_parameter.group_name
139+
parameter_design.template_file = existing_parameter.template_file
140+
parameter_design.output_file = existing_parameter.output_file
141+
all_params.append(parameter_design)
142+
overlap = True
143+
overlap_set.add(existing_parameter.param_name)
144+
break
145+
if not overlap:
146+
all_params.append(existing_parameter)
147+
148+
for parameter_design in self.scalars:
149+
if parameter_design.param_name not in overlap_set:
150+
all_params.append(parameter_design)
151+
152+
return ScalarParameters(scalars=all_params)
161153

162154
def read_and_validate_design_matrix(
163155
self,
164-
) -> tuple[list[bool], pd.DataFrame, GenKwConfig]:
156+
) -> tuple[list[bool], pd.DataFrame, list[ScalarParameter]]:
165157
# Read the parameter names (first row) as strings to prevent pandas from modifying them.
166158
# This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet.
167159
# By doing this, we can properly validate variable names, including detecting duplicates or missing names.
@@ -207,29 +199,25 @@ def read_and_validate_design_matrix(
207199

208200
design_matrix_df = pd.concat([design_matrix_df, default_df], axis=1)
209201

210-
transform_function_definitions: list[TransformFunctionDefinition] = []
202+
scalars: list[ScalarParameter] = []
211203
for parameter in design_matrix_df.columns:
212-
transform_function_definitions.append(
213-
TransformFunctionDefinition(
214-
name=parameter,
215-
param_name="RAW",
216-
values=[],
204+
scalars.append(
205+
ScalarParameter(
206+
param_name=parameter,
207+
group_name=DESIGN_MATRIX_GROUP,
208+
input_source=DataSource.DESIGN_MATRIX,
209+
distribution=TransRawSettings(),
210+
template_file=None,
211+
output_file=None,
212+
update=False,
217213
)
218214
)
219-
parameter_configuration = GenKwConfig(
220-
name=DESIGN_MATRIX_GROUP,
221-
forward_init=False,
222-
template_file=None,
223-
output_file=None,
224-
transform_function_definitions=transform_function_definitions,
225-
update=False,
226-
)
227215

228216
reals = design_matrix_df.index.tolist()
229217
return (
230218
[x in reals for x in range(max(reals) + 1)],
231219
design_matrix_df,
232-
parameter_configuration,
220+
scalars,
233221
)
234222

235223
@staticmethod

0 commit comments

Comments
 (0)