Skip to content

Commit 5b49260

Browse files
committed
Add run experiment with design matrix to ensemble experiment panel
- Catagorical data is not treated properly yet, wherein the design matrix group that contains catagorical data will automatically store all parameters inside this group to objects; ie, strings. - Prefil active realization box with realizations from design matrix - Use design_matrix parameters in ensemble experiment - add test run cli with design matrix and poly example - add test that save parameters internalize DataFrame parameters in the storage - add merge function to merge design parameters with existing parameters -- Raise Validation error when having multiple overlapping groups - Update writting to parameter.txt with categorical values
1 parent d875982 commit 5b49260

File tree

10 files changed

+519
-69
lines changed

10 files changed

+519
-69
lines changed

src/ert/config/design_matrix.py

+69-13
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
from .parsing import ConfigValidationError, ErrorInfo
1515

1616
if TYPE_CHECKING:
17-
from ert.config import (
18-
ParameterConfig,
19-
)
17+
from ert.config import ParameterConfig
2018

2119
DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"
2220

@@ -28,10 +26,17 @@ class DesignMatrix:
2826
default_sheet: str
2927

3028
def __post_init__(self) -> None:
31-
self.num_realizations: int | None = None
32-
self.active_realizations: list[bool] | None = None
33-
self.design_matrix_df: pd.DataFrame | None = None
34-
self.parameter_configuration: dict[str, ParameterConfig] | None = None
29+
try:
30+
(
31+
self.active_realizations,
32+
self.design_matrix_df,
33+
self.parameter_configuration,
34+
) = self.read_design_matrix()
35+
except (ValueError, AttributeError) as exc:
36+
raise ConfigValidationError.with_context(
37+
f"Error reading design matrix {self.xls_filename}: {exc}",
38+
str(self.xls_filename),
39+
) from exc
3540

3641
@classmethod
3742
def from_config_list(cls, config_list: list[str]) -> DesignMatrix:
@@ -73,9 +78,60 @@ def from_config_list(cls, config_list: list[str]) -> DesignMatrix:
7378
default_sheet=default_sheet,
7479
)
7580

81+
def merge_with_existing_parameters(
82+
self, existing_parameters: list[ParameterConfig]
83+
) -> tuple[list[ParameterConfig], ParameterConfig | None]:
84+
"""
85+
This method merges the design matrix parameters with the existing parameters and
86+
returns the new list of existing parameters, wherein we drop GEN_KW group having a full overlap with the design matrix group.
87+
GEN_KW group that was dropped will acquire a new name from the design matrix group.
88+
Additionally, the ParameterConfig which is the design matrix group is returned separately.
89+
90+
Args:
91+
existing_parameters (List[ParameterConfig]): List of existing parameters
92+
93+
Raises:
94+
ConfigValidationError: If there is a partial overlap between the design matrix group and any existing GEN_KW group
95+
96+
Returns:
97+
tuple[List[ParameterConfig], ParameterConfig]: List of existing parameters and the dedicated design matrix group
98+
"""
99+
100+
new_param_config: list[ParameterConfig] = []
101+
102+
design_parameter_group = self.parameter_configuration[DESIGN_MATRIX_GROUP]
103+
design_keys = []
104+
if isinstance(design_parameter_group, GenKwConfig):
105+
design_keys = [e.name for e in design_parameter_group.transform_functions]
106+
107+
design_group_added = False
108+
for parameter_group in existing_parameters:
109+
if not isinstance(parameter_group, GenKwConfig):
110+
new_param_config += [parameter_group]
111+
continue
112+
existing_keys = [e.name for e in parameter_group.transform_functions]
113+
if set(existing_keys) == set(design_keys):
114+
if design_group_added:
115+
raise ConfigValidationError(
116+
"Multiple overlapping groups with design matrix found in existing parameters!\n"
117+
f"{design_parameter_group.name} and {parameter_group.name}"
118+
)
119+
120+
design_parameter_group.name = parameter_group.name
121+
design_group_added = True
122+
elif set(design_keys) & set(existing_keys):
123+
raise ConfigValidationError(
124+
"Overlapping parameter names found in design matrix!\n"
125+
f"{DESIGN_MATRIX_GROUP}:{design_keys}\n{parameter_group.name}:{existing_keys}"
126+
"\nThey need to much exactly or not at all."
127+
)
128+
else:
129+
new_param_config += [parameter_group]
130+
return new_param_config, design_parameter_group
131+
76132
def read_design_matrix(
77133
self,
78-
) -> None:
134+
) -> tuple[list[bool], pd.DataFrame, dict[str, ParameterConfig]]:
79135
# Read the parameter names (first row) as strings to prevent pandas from modifying them.
80136
# This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet.
81137
# By doing this, we can properly validate variable names, including detecting duplicates or missing names.
@@ -139,11 +195,11 @@ def read_design_matrix(
139195
[[DESIGN_MATRIX_GROUP], design_matrix_df.columns]
140196
)
141197
reals = design_matrix_df.index.tolist()
142-
self.num_realizations = len(reals)
143-
self.active_realizations = [x in reals for x in range(max(reals) + 1)]
144-
145-
self.design_matrix_df = design_matrix_df
146-
self.parameter_configuration = parameter_configuration
198+
return (
199+
[x in reals for x in range(max(reals) + 1)],
200+
design_matrix_df,
201+
parameter_configuration,
202+
)
147203

148204
@staticmethod
149205
def _read_excel(

src/ert/enkf_main.py

+31-8
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,17 @@
1010
from typing import TYPE_CHECKING, Any
1111

1212
import orjson
13+
import pandas as pd
14+
import xarray as xr
1315
from numpy.random import SeedSequence
1416

1517
from ert.config.ert_config import forward_model_data_to_json
1618
from ert.config.forward_model_step import ForwardModelStep
1719
from ert.config.model_config import ModelConfig
1820
from ert.substitutions import Substitutions, substitute_runpath_name
1921

20-
from .config import (
21-
ExtParamConfig,
22-
Field,
23-
GenKwConfig,
24-
ParameterConfig,
25-
SurfaceConfig,
26-
)
22+
from .config import ExtParamConfig, Field, GenKwConfig, ParameterConfig, SurfaceConfig
23+
from .config.design_matrix import DESIGN_MATRIX_GROUP
2724
from .run_arg import RunArg
2825
from .runpaths import Runpaths
2926

@@ -53,7 +50,10 @@ def _value_export_txt(
5350
with path.open("w") as f:
5451
for key, param_map in values.items():
5552
for param, value in param_map.items():
56-
print(f"{key}:{param} {value:g}", file=f)
53+
if isinstance(value, (int | float)):
54+
print(f"{key}:{param} {value:g}", file=f)
55+
else:
56+
print(f"{key}:{param} {value}", file=f)
5757

5858

5959
def _value_export_json(
@@ -156,6 +156,29 @@ def _seed_sequence(seed: int | None) -> int:
156156
return int_seed
157157

158158

159+
def save_design_matrix_to_ensemble(
160+
design_matrix_df: pd.DataFrame,
161+
ensemble: Ensemble,
162+
active_realizations: Iterable[int],
163+
design_group_name: str = DESIGN_MATRIX_GROUP,
164+
) -> None:
165+
assert not design_matrix_df.empty
166+
for realization_nr in active_realizations:
167+
row = design_matrix_df.loc[realization_nr][DESIGN_MATRIX_GROUP]
168+
ds = xr.Dataset(
169+
{
170+
"values": ("names", list(row.values)),
171+
"transformed_values": ("names", list(row.values)),
172+
"names": list(row.keys()),
173+
}
174+
)
175+
ensemble.save_parameters(
176+
design_group_name,
177+
realization_nr,
178+
ds,
179+
)
180+
181+
159182
def sample_prior(
160183
ensemble: Ensemble,
161184
active_realizations: Iterable[int],

src/ert/gui/simulation/ensemble_experiment_panel.py

+12-18
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ert.gui.tools.design_matrix.design_matrix_panel import DesignMatrixPanel
1616
from ert.mode_definitions import ENSEMBLE_EXPERIMENT_MODE
1717
from ert.run_models import EnsembleExperiment
18-
from ert.validation import RangeStringArgument
18+
from ert.validation import ActiveRange, RangeStringArgument
1919
from ert.validation.proper_name_argument import ExperimentValidation, ProperNameArgument
2020

2121
from .experiment_config_panel import ExperimentConfigPanel
@@ -85,6 +85,9 @@ def __init__(
8585

8686
design_matrix = analysis_config.design_matrix
8787
if design_matrix is not None:
88+
self._active_realizations_field.setText(
89+
ActiveRange(design_matrix.active_realizations).rangestring
90+
)
8891
show_dm_param_button = QPushButton("Show parameters")
8992
show_dm_param_button.setObjectName("show-dm-parameters")
9093
show_dm_param_button.setMinimumWidth(50)
@@ -113,23 +116,14 @@ def __init__(
113116
self.notifier.ertChanged.connect(self._update_experiment_name_placeholder)
114117

115118
def on_show_dm_params_clicked(self, design_matrix: DesignMatrix) -> None:
116-
assert design_matrix is not None
117-
118-
if design_matrix.design_matrix_df is None:
119-
design_matrix.read_design_matrix()
120-
121-
if (
122-
design_matrix.design_matrix_df is not None
123-
and not design_matrix.design_matrix_df.empty
124-
):
125-
viewer = DesignMatrixPanel(
126-
design_matrix.design_matrix_df,
127-
design_matrix.xls_filename.name,
128-
)
129-
viewer.setMinimumHeight(500)
130-
viewer.setMinimumWidth(1000)
131-
viewer.adjustSize()
132-
viewer.exec_()
119+
viewer = DesignMatrixPanel(
120+
design_matrix.design_matrix_df,
121+
design_matrix.xls_filename.name,
122+
)
123+
viewer.setMinimumHeight(500)
124+
viewer.setMinimumWidth(1000)
125+
viewer.adjustSize()
126+
viewer.exec_()
133127

134128
@Slot(ExperimentConfigPanel)
135129
def experimentTypeChanged(self, w: ExperimentConfigPanel) -> None:

src/ert/run_models/ensemble_experiment.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66

77
import numpy as np
88

9-
from ert.enkf_main import sample_prior
9+
from ert.config import ConfigValidationError
10+
from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble
1011
from ert.ensemble_evaluator import EvaluatorServerConfig
1112
from ert.storage import Ensemble, Experiment, Storage
1213
from ert.trace import tracer
1314

1415
from ..run_arg import create_run_arguments
15-
from .base_run_model import BaseRunModel, StatusEvents
16+
from .base_run_model import BaseRunModel, ErtRunError, StatusEvents
1617

1718
if TYPE_CHECKING:
1819
from ert.config import ErtConfig, QueueConfig
@@ -64,10 +65,27 @@ def run_experiment(
6465
) -> None:
6566
self.log_at_startup()
6667
self.restart = restart
68+
# If design matrix is present, we try to merge design matrix parameters
69+
# to the experiment parameters and set new active realizations
70+
parameters_config = self.ert_config.ensemble_config.parameter_configuration
71+
design_matrix = self.ert_config.analysis_config.design_matrix
72+
design_matrix_group = None
73+
if design_matrix is not None:
74+
try:
75+
parameters_config, design_matrix_group = (
76+
design_matrix.merge_with_existing_parameters(parameters_config)
77+
)
78+
except ConfigValidationError as exc:
79+
raise ErtRunError(str(exc)) from exc
80+
6781
if not restart:
6882
self.experiment = self._storage.create_experiment(
6983
name=self.experiment_name,
70-
parameters=self.ert_config.ensemble_config.parameter_configuration,
84+
parameters=(
85+
[*parameters_config, design_matrix_group]
86+
if design_matrix_group is not None
87+
else parameters_config
88+
),
7189
observations=self.ert_config.observations,
7290
responses=self.ert_config.ensemble_config.response_configuration,
7391
)
@@ -90,12 +108,21 @@ def run_experiment(
90108
np.array(self.active_realizations, dtype=bool),
91109
ensemble=self.ensemble,
92110
)
111+
93112
sample_prior(
94113
self.ensemble,
95114
np.where(self.active_realizations)[0],
96115
random_seed=self.random_seed,
97116
)
98117

118+
if design_matrix_group is not None and design_matrix is not None:
119+
save_design_matrix_to_ensemble(
120+
design_matrix.design_matrix_df,
121+
self.ensemble,
122+
np.where(self.active_realizations)[0],
123+
design_matrix_group.name,
124+
)
125+
99126
self._evaluate_and_postprocess(
100127
run_args,
101128
self.ensemble,

src/ert/run_models/model_factory.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,20 @@ def _setup_ensemble_experiment(
117117
args: Namespace,
118118
status_queue: SimpleQueue[StatusEvents],
119119
) -> EnsembleExperiment:
120-
active_realizations = _realizations(args, config.model_config.num_realizations)
120+
active_realizations = _realizations(
121+
args, config.model_config.num_realizations
122+
).tolist()
123+
if (
124+
config.analysis_config.design_matrix is not None
125+
and config.analysis_config.design_matrix.active_realizations is not None
126+
):
127+
active_realizations = config.analysis_config.design_matrix.active_realizations
121128
experiment_name = args.experiment_name
122129
assert experiment_name is not None
123130

124131
return EnsembleExperiment(
125132
random_seed=config.random_seed,
126-
active_realizations=active_realizations.tolist(),
133+
active_realizations=active_realizations,
127134
ensemble_name=args.current_ensemble,
128135
minimum_required_realizations=config.analysis_config.minimum_required_realizations,
129136
experiment_name=experiment_name,
@@ -271,9 +278,9 @@ def _setup_iterative_ensemble_smoother(
271278
random_seed=config.random_seed,
272279
active_realizations=active_realizations.tolist(),
273280
target_ensemble=_iterative_ensemble_format(args),
274-
number_of_iterations=int(args.num_iterations)
275-
if args.num_iterations is not None
276-
else 4,
281+
number_of_iterations=(
282+
int(args.num_iterations) if args.num_iterations is not None else 4
283+
),
277284
minimum_required_realizations=config.analysis_config.minimum_required_realizations,
278285
num_retries_per_iter=4,
279286
experiment_name=experiment_name,

0 commit comments

Comments
 (0)