Skip to content

Commit 68d4b20

Browse files
committed
Add ScalarParameters config
- transformations become dataclases -> TODO: pydantic - ScalarParameter represents a single parameter instance with a datasource (DESIGN_MATRIX, SAMPLED) - all ScalarParameter instances are bound by ScalarParameters(ParameterConfig) and are stored for all realizations in a single pl.Dataframe (parquet) file. - API to load / save scalar parameters is in local_ensemble load|save_param_scalar - DESIGN_MATRIX now creates only ScalarParameters instances and modifies the ScalarParameterc Config. - Reimplement load_all_gen_kwn_data
1 parent 657656b commit 68d4b20

File tree

23 files changed

+1201
-462
lines changed

23 files changed

+1201
-462
lines changed

src/ert/analysis/_es_update.py

+34-19
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,7 @@
55
import time
66
from collections.abc import Callable, Iterable, Sequence
77
from fnmatch import fnmatch
8-
from typing import (
9-
TYPE_CHECKING,
10-
Generic,
11-
Self,
12-
TypeVar,
13-
)
8+
from typing import TYPE_CHECKING, Generic, Self, TypeVar
149

1510
import iterative_ensemble_smoother as ies
1611
import numpy as np
@@ -19,7 +14,13 @@
1914
import scipy
2015
from iterative_ensemble_smoother.experimental import AdaptiveESMDA
2116

22-
from ert.config import ESSettings, GenKwConfig, ObservationGroups, UpdateSettings
17+
from ert.config import (
18+
ESSettings,
19+
GenKwConfig,
20+
ObservationGroups,
21+
ScalarParameters,
22+
UpdateSettings,
23+
)
2324

2425
from . import misfit_preprocessor
2526
from .event import (
@@ -31,10 +32,7 @@
3132
AnalysisTimeEvent,
3233
DataSection,
3334
)
34-
from .snapshots import (
35-
ObservationAndResponseSnapshot,
36-
SmootherSnapshot,
37-
)
35+
from .snapshots import ObservationAndResponseSnapshot, SmootherSnapshot
3836

3937
if TYPE_CHECKING:
4038
import numpy.typing as npt
@@ -111,16 +109,26 @@ def _all_parameters(
111109

112110

113111
def _save_param_ensemble_array_to_disk(
114-
ensemble: Ensemble,
112+
source_ensemble: Ensemble,
113+
target_ensemble: Ensemble,
115114
param_ensemble_array: npt.NDArray[np.float64],
116115
param_group: str,
117116
iens_active_index: npt.NDArray[np.int_],
118117
) -> None:
119-
config_node = ensemble.experiment.parameter_configuration[param_group]
120-
for i, realization in enumerate(iens_active_index):
121-
config_node.save_parameters(
122-
ensemble, param_group, realization, param_ensemble_array[:, i]
118+
config_node = target_ensemble.experiment.parameter_configuration[param_group]
119+
if isinstance(config_node, ScalarParameters):
120+
config_node.save_updated_parameters_and_copy_remaining(
121+
source_ensemble,
122+
target_ensemble,
123+
param_group,
124+
iens_active_index,
125+
param_ensemble_array,
123126
)
127+
else:
128+
for i, realization in enumerate(iens_active_index):
129+
config_node.save_parameters(
130+
target_ensemble, param_group, realization, param_ensemble_array[:, i]
131+
)
124132

125133

126134
def _load_param_ensemble_array(
@@ -129,7 +137,11 @@ def _load_param_ensemble_array(
129137
iens_active_index: npt.NDArray[np.int_],
130138
) -> npt.NDArray[np.float64]:
131139
config_node = ensemble.experiment.parameter_configuration[param_group]
132-
return config_node.load_parameters(ensemble, param_group, iens_active_index)
140+
if isinstance(config_node, ScalarParameters):
141+
return config_node.load_parameters_to_update(ensemble, iens_active_index)
142+
dataset = config_node.load_parameters(ensemble, param_group, iens_active_index)
143+
assert isinstance(dataset, np.ndarray), "dataset is not an numpy array"
144+
return dataset
133145

134146

135147
def _expand_wildcards(
@@ -608,9 +620,12 @@ def correlation_callback(
608620
logger.info(log_msg)
609621
progress_callback(AnalysisStatusEvent(msg=log_msg))
610622
start = time.time()
611-
612623
_save_param_ensemble_array_to_disk(
613-
target_ensemble, param_ensemble_array, param_group, iens_active_index
624+
source_ensemble,
625+
target_ensemble,
626+
param_ensemble_array,
627+
param_group,
628+
iens_active_index,
614629
)
615630
logger.info(
616631
f"Storing data for {param_group} completed in {(time.time() - start) / 60} minutes"

src/ert/config/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from .parsing.observations_parser import ObservationType
3232
from .queue_config import QueueConfig
3333
from .response_config import InvalidResponseFile, ResponseConfig
34+
from .scalar_parameter import DataSource, ScalarParameter, ScalarParameters
3435
from .summary_config import SummaryConfig
3536
from .summary_observation import SummaryObservation
3637
from .surface_config import SurfaceConfig
@@ -44,6 +45,7 @@
4445
"ConfigValidationError",
4546
"ConfigValidationError",
4647
"ConfigWarning",
48+
"DataSource",
4749
"DesignMatrix",
4850
"ESSettings",
4951
"EnkfObs",
@@ -70,6 +72,8 @@
7072
"QueueConfig",
7173
"QueueSystem",
7274
"ResponseConfig",
75+
"ScalarParameter",
76+
"ScalarParameters",
7377
"SummaryConfig",
7478
"SummaryObservation",
7579
"SurfaceConfig",

src/ert/config/design_matrix.py

+53-59
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,18 @@
88
import pandas as pd
99
from pandas.api.types import is_integer_dtype
1010

11-
from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition
12-
1311
from ._option_dict import option_dict
1412
from .parsing import ConfigValidationError, ErrorInfo
13+
from .scalar_parameter import (
14+
SCALAR_PARAMETERS_NAME,
15+
DataSource,
16+
ScalarParameter,
17+
ScalarParameters,
18+
TransRawSettings,
19+
)
1520

1621
if TYPE_CHECKING:
17-
from ert.config import ParameterConfig
22+
pass
1823

1924
DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"
2025

@@ -32,7 +37,7 @@ def __post_init__(self) -> None:
3237
(
3338
self.active_realizations,
3439
self.design_matrix_df,
35-
self.parameter_configuration,
40+
self.scalars,
3641
) = self.read_design_matrix()
3742
except (ValueError, AttributeError) as exc:
3843
raise ConfigValidationError.with_context(
@@ -102,64 +107,57 @@ def merge_with_other(self, dm_other: DesignMatrix) -> None:
102107
except ValueError as exc:
103108
errors.append(ErrorInfo(f"Error when merging design matrices {exc}!"))
104109

105-
for tfd in dm_other.parameter_configuration.transform_function_definitions:
106-
self.parameter_configuration.transform_function_definitions.append(tfd)
110+
for param in dm_other.scalars:
111+
self.scalars.append(param)
107112

108113
if errors:
109114
raise ConfigValidationError.from_collected(errors)
110115

111116
def merge_with_existing_parameters(
112-
self, existing_parameters: list[ParameterConfig]
113-
) -> tuple[list[ParameterConfig], GenKwConfig]:
117+
self, existing_scalars: ScalarParameters
118+
) -> ScalarParameters:
114119
"""
115120
This method merges the design matrix parameters with the existing parameters and
116-
returns the new list of existing parameters, wherein we drop GEN_KW group having a full overlap with the design matrix group.
117-
GEN_KW group that was dropped will acquire a new name from the design matrix group.
118-
Additionally, the ParameterConfig which is the design matrix group is returned separately.
119-
121+
returns the new list of existing parameters.
120122
Args:
121-
existing_parameters (List[ParameterConfig]): List of existing parameters
123+
existing_scalars (ScalarParameters): existing scalar parameters
122124
123-
Raises:
124-
ConfigValidationError: If there is a partial overlap between the design matrix group and any existing GEN_KW group
125125
126126
Returns:
127-
tuple[List[ParameterConfig], ParameterConfig]: List of existing parameters and the dedicated design matrix group
127+
ScalarParameters: new set of ScalarParameters
128128
"""
129129

130-
new_param_config: list[ParameterConfig] = []
130+
all_params: list[ScalarParameter] = []
131131

132-
design_parameter_group = self.parameter_configuration
133-
design_keys = [e.name for e in design_parameter_group.transform_functions]
134-
135-
design_group_added = False
136-
for parameter_group in existing_parameters:
137-
if not isinstance(parameter_group, GenKwConfig):
138-
new_param_config += [parameter_group]
132+
overlap_set = set()
133+
for parameter_sampled in existing_scalars.scalars:
134+
if parameter_sampled.input_source == DataSource.DESIGN_MATRIX:
139135
continue
140-
existing_keys = [e.name for e in parameter_group.transform_functions]
141-
if set(existing_keys) == set(design_keys):
142-
if design_group_added:
143-
raise ConfigValidationError(
144-
"Multiple overlapping groups with design matrix found in existing parameters!\n"
145-
f"{design_parameter_group.name} and {parameter_group.name}"
146-
)
147-
148-
design_parameter_group.name = parameter_group.name
149-
design_group_added = True
150-
elif set(design_keys) & set(existing_keys):
151-
raise ConfigValidationError(
152-
"Overlapping parameter names found in design matrix!\n"
153-
f"{DESIGN_MATRIX_GROUP}:{design_keys}\n{parameter_group.name}:{existing_keys}"
154-
"\nThey need to match exactly or not at all."
155-
)
156-
else:
157-
new_param_config += [parameter_group]
158-
return new_param_config, design_parameter_group
136+
overlap = False
137+
for parameter_design in self.scalars:
138+
if parameter_sampled.param_name == parameter_design.param_name:
139+
parameter_design.group_name = parameter_sampled.group_name
140+
all_params.append(parameter_design)
141+
overlap = True
142+
overlap_set.add(parameter_sampled.param_name)
143+
break
144+
if not overlap:
145+
all_params.append(parameter_sampled)
146+
147+
for parameter_design in self.scalars:
148+
if parameter_design.param_name not in overlap_set:
149+
all_params.append(parameter_design)
150+
151+
return ScalarParameters(
152+
name=SCALAR_PARAMETERS_NAME,
153+
forward_init=False,
154+
update=True,
155+
scalars=all_params,
156+
)
159157

160158
def read_design_matrix(
161159
self,
162-
) -> tuple[list[bool], pd.DataFrame, GenKwConfig]:
160+
) -> tuple[list[bool], pd.DataFrame, list[ScalarParameter]]:
163161
# Read the parameter names (first row) as strings to prevent pandas from modifying them.
164162
# This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet.
165163
# By doing this, we can properly validate variable names, including detecting duplicates or missing names.
@@ -205,29 +203,25 @@ def read_design_matrix(
205203

206204
design_matrix_df = pd.concat([design_matrix_df, default_df], axis=1)
207205

208-
transform_function_definitions: list[TransformFunctionDefinition] = []
206+
scalars: list[ScalarParameter] = []
209207
for parameter in design_matrix_df.columns:
210-
transform_function_definitions.append(
211-
TransformFunctionDefinition(
212-
name=parameter,
213-
param_name="RAW",
214-
values=[],
208+
scalars.append(
209+
ScalarParameter(
210+
param_name=parameter,
211+
group_name=DESIGN_MATRIX_GROUP,
212+
input_source=DataSource.DESIGN_MATRIX,
213+
distribution=TransRawSettings(),
214+
template_file=None,
215+
output_file=None,
216+
update=False,
215217
)
216218
)
217-
parameter_configuration = GenKwConfig(
218-
name=DESIGN_MATRIX_GROUP,
219-
forward_init=False,
220-
template_file=None,
221-
output_file=None,
222-
transform_function_definitions=transform_function_definitions,
223-
update=False,
224-
)
225219

226220
reals = design_matrix_df.index.tolist()
227221
return (
228222
[x in reals for x in range(max(reals) + 1)],
229223
design_matrix_df,
230-
parameter_configuration,
224+
scalars,
231225
)
232226

233227
@staticmethod

src/ert/config/ensemble_config.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .parsing import ConfigDict, ConfigKeys, ConfigValidationError
1717
from .refcase import Refcase
1818
from .response_config import ResponseConfig
19+
from .scalar_parameter import ScalarParameters
1920
from .summary_config import SummaryConfig
2021
from .surface_config import SurfaceConfig
2122

@@ -47,7 +48,8 @@ class EnsembleConfig:
4748
default_factory=dict
4849
)
4950
parameter_configs: dict[
50-
str, GenKwConfig | FieldConfig | SurfaceConfig | ExtParamConfig
51+
str,
52+
GenKwConfig | FieldConfig | SurfaceConfig | ExtParamConfig | ScalarParameters,
5153
] = field(default_factory=dict)
5254
refcase: Refcase | None = None
5355

@@ -134,7 +136,7 @@ def make_field(field_list: list[str]) -> FieldConfig:
134136
return FieldConfig.from_config_list(grid_file_path, dims, field_list)
135137

136138
parameter_configs = (
137-
[GenKwConfig.from_config_list(g) for g in gen_kw_list]
139+
[ScalarParameters.from_config_list(gen_kw_list)]
138140
+ [SurfaceConfig.from_config_list(s) for s in surface_list]
139141
+ [make_field(f) for f in field_list]
140142
)
@@ -180,12 +182,12 @@ def hasNodeGenData(self, key: str) -> bool:
180182
config = self.response_configs["gen_data"]
181183
return key in config.keys
182184

185+
# TODO: This might not be needed but it retrieves the group names for genkw config
183186
def get_keylist_gen_kw(self) -> list[str]:
184-
return [
185-
val.name
186-
for val in self.parameter_configuration
187-
if isinstance(val, GenKwConfig)
188-
]
187+
for val in self.parameter_configuration:
188+
if isinstance(val, ScalarParameters):
189+
return list(val.groups.keys())
190+
return []
189191

190192
@property
191193
def parameters(self) -> list[str]:

0 commit comments

Comments
 (0)