Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit a670fcb

Browse files
committedDec 10, 2024·
Add run experiment with design matrix to ensemble experiment panel
- Prefill the active realization box with realizations from the design matrix - Use design_matrix parameters in ensemble experiment - Add a test that runs the CLI with a design matrix and the poly example - Add a test that saving parameters internalizes DataFrame parameters in the storage - Add a merge function to merge design parameters with existing parameters -- Raise a validation error when there are multiple overlapping groups
1 parent bf2d515 commit a670fcb

File tree

10 files changed

+516
-72
lines changed

10 files changed

+516
-72
lines changed
 

‎src/ert/config/design_matrix.py

+74-17
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,10 @@
1111
from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition
1212

1313
from ._option_dict import option_dict
14-
from .parsing import (
15-
ConfigValidationError,
16-
ErrorInfo,
17-
)
14+
from .parsing import ConfigValidationError, ErrorInfo
1815

1916
if TYPE_CHECKING:
20-
from ert.config import (
21-
ParameterConfig,
22-
)
17+
from ert.config import ParameterConfig
2318

2419
DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"
2520

@@ -31,10 +26,17 @@ class DesignMatrix:
3126
default_sheet: str
3227

3328
def __post_init__(self) -> None:
34-
self.num_realizations: Optional[int] = None
35-
self.active_realizations: Optional[List[bool]] = None
36-
self.design_matrix_df: Optional[pd.DataFrame] = None
37-
self.parameter_configuration: Optional[Dict[str, ParameterConfig]] = None
29+
try:
30+
(
31+
self.active_realizations,
32+
self.design_matrix_df,
33+
self.parameter_configuration,
34+
) = self.read_design_matrix()
35+
except (ValueError, AttributeError) as exc:
36+
raise ConfigValidationError.with_context(
37+
f"Error reading design matrix {self.xls_filename}: {exc}",
38+
str(self.xls_filename),
39+
) from exc
3840

3941
@classmethod
4042
def from_config_list(cls, config_list: List[str]) -> "DesignMatrix":
@@ -76,9 +78,64 @@ def from_config_list(cls, config_list: List[str]) -> "DesignMatrix":
7678
default_sheet=default_sheet,
7779
)
7880

81+
def merge_with_existing_parameters(
    self, existing_parameters: List[ParameterConfig]
) -> tuple[List[ParameterConfig], ParameterConfig | None]:
    """
    Merge the design matrix parameter group with the existing parameters.

    Each existing GEN_KW group is compared, by parameter names, with the
    design matrix group:

    - same set of names: the GEN_KW group is left out of the returned list
      and the design matrix group takes over its name;
    - some, but not all, names in common: a ConfigValidationError is raised;
    - no names in common: the GEN_KW group is kept.

    Parameters that are not GEN_KW groups are always kept.

    Note that the rename is done in place on the group stored in
    ``self.parameter_configuration``, so it is visible to later callers.

    Args:
        existing_parameters (List[ParameterConfig]): List of existing parameters

    Raises:
        ConfigValidationError: If there is a partial overlap between the design
            matrix group and any existing GEN_KW group, or if more than one
            existing GEN_KW group fully overlaps the design matrix group

    Returns:
        tuple[List[ParameterConfig], ParameterConfig]: List of existing
            parameters that were kept and the dedicated design matrix group
    """
    design_parameter_group = self.parameter_configuration[DESIGN_MATRIX_GROUP]
    design_keys = []
    if isinstance(design_parameter_group, GenKwConfig):
        design_keys = [e.name for e in design_parameter_group.transform_functions]
    # The design matrix names do not change while looping; build the set once.
    design_key_set = set(design_keys)

    new_param_config: List[ParameterConfig] = []
    design_group_added = False
    for parameter_group in existing_parameters:
        if not isinstance(parameter_group, GenKwConfig):
            new_param_config.append(parameter_group)
            continue

        existing_keys = [e.name for e in parameter_group.transform_functions]
        existing_key_set = set(existing_keys)
        if existing_key_set == design_key_set:
            if design_group_added:
                # At this point the design group already carries the name of
                # the first fully overlapping group, so both names are shown.
                raise ConfigValidationError(
                    "Multiple overlapping groups with design matrix found in existing parameters!\n"
                    f"{design_parameter_group.name} and {parameter_group.name}"
                )
            design_parameter_group.name = parameter_group.name
            design_group_added = True
        elif design_key_set & existing_key_set:
            raise ConfigValidationError(
                "Overlapping parameter names found in design matrix!\n"
                f"{DESIGN_MATRIX_GROUP}:{design_keys}\n{parameter_group.name}:{existing_keys}"
                "\nThey need to match exactly or not at all."
            )
        else:
            new_param_config.append(parameter_group)
    return new_param_config, design_parameter_group
135+
79136
def read_design_matrix(
80137
self,
81-
) -> None:
138+
) -> tuple[List[bool], pd.DataFrame, Dict[str, ParameterConfig]]:
82139
# Read the parameter names (first row) as strings to prevent pandas from modifying them.
83140
# This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet.
84141
# By doing this, we can properly validate variable names, including detecting duplicates or missing names.
@@ -142,11 +199,11 @@ def read_design_matrix(
142199
[[DESIGN_MATRIX_GROUP], design_matrix_df.columns]
143200
)
144201
reals = design_matrix_df.index.tolist()
145-
self.num_realizations = len(reals)
146-
self.active_realizations = [x in reals for x in range(max(reals) + 1)]
147-
148-
self.design_matrix_df = design_matrix_df
149-
self.parameter_configuration = parameter_configuration
202+
return (
203+
[x in reals for x in range(max(reals) + 1)],
204+
design_matrix_df,
205+
parameter_configuration,
206+
)
150207

151208
@staticmethod
152209
def _read_excel(

‎src/ert/enkf_main.py

+27-7
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,17 @@
1919
)
2020

2121
import orjson
22+
import pandas as pd
23+
import xarray as xr
2224
from numpy.random import SeedSequence
2325

2426
from ert.config.ert_config import forward_model_data_to_json
2527
from ert.config.forward_model_step import ForwardModelStep
2628
from ert.config.model_config import ModelConfig
2729
from ert.substitutions import Substitutions, substitute_runpath_name
2830

29-
from .config import (
30-
ExtParamConfig,
31-
Field,
32-
GenKwConfig,
33-
ParameterConfig,
34-
SurfaceConfig,
35-
)
31+
from .config import ExtParamConfig, Field, GenKwConfig, ParameterConfig, SurfaceConfig
32+
from .config.design_matrix import DESIGN_MATRIX_GROUP
3633
from .run_arg import RunArg
3734
from .runpaths import Runpaths
3835

@@ -165,6 +162,29 @@ def _seed_sequence(seed: Optional[int]) -> int:
165162
return int_seed
166163

167164

165+
def save_design_matrix_to_ensemble(
    design_matrix_df: pd.DataFrame,
    ensemble: Ensemble,
    active_realizations: Iterable[int],
    design_group_name: str = DESIGN_MATRIX_GROUP,
) -> None:
    """Save one parameter dataset per active realization to the ensemble.

    For every realization number, the row with that index label is looked up
    in ``design_matrix_df`` and its DESIGN_MATRIX_GROUP columns are stored
    under ``design_group_name``. The same numbers are written as both
    "values" and "transformed_values".

    Args:
        design_matrix_df: Design matrix; must not be empty.
        ensemble: Ensemble that receives the parameters.
        active_realizations: Realization numbers to save.
        design_group_name: Name of the parameter group to save under.
    """
    assert not design_matrix_df.empty
    for realization_nr in active_realizations:
        # Columns are always read from DESIGN_MATRIX_GROUP, even when the
        # group is saved under another name.
        realization_row = design_matrix_df.loc[realization_nr][DESIGN_MATRIX_GROUP]
        parameter_names = list(realization_row.keys())
        dataset = xr.Dataset(
            {
                "values": ("names", list(realization_row.values)),
                "transformed_values": ("names", list(realization_row.values)),
                "names": parameter_names,
            }
        )
        ensemble.save_parameters(design_group_name, realization_nr, dataset)
186+
187+
168188
def sample_prior(
169189
ensemble: Ensemble,
170190
active_realizations: Iterable[int],

‎src/ert/gui/simulation/ensemble_experiment_panel.py

+12-18
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ert.gui.tools.design_matrix.design_matrix_panel import DesignMatrixPanel
1616
from ert.mode_definitions import ENSEMBLE_EXPERIMENT_MODE
1717
from ert.run_models import EnsembleExperiment
18-
from ert.validation import RangeStringArgument
18+
from ert.validation import ActiveRange, RangeStringArgument
1919
from ert.validation.proper_name_argument import ExperimentValidation, ProperNameArgument
2020

2121
from .experiment_config_panel import ExperimentConfigPanel
@@ -85,6 +85,9 @@ def __init__(
8585

8686
design_matrix = analysis_config.design_matrix
8787
if design_matrix is not None:
88+
self._active_realizations_field.setText(
89+
ActiveRange(design_matrix.active_realizations).rangestring
90+
)
8891
show_dm_param_button = QPushButton("Show parameters")
8992
show_dm_param_button.setObjectName("show-dm-parameters")
9093
show_dm_param_button.setMinimumWidth(50)
@@ -113,23 +116,14 @@ def __init__(
113116
self.notifier.ertChanged.connect(self._update_experiment_name_placeholder)
114117

115118
def on_show_dm_params_clicked(self, design_matrix: DesignMatrix) -> None:
116-
assert design_matrix is not None
117-
118-
if design_matrix.design_matrix_df is None:
119-
design_matrix.read_design_matrix()
120-
121-
if (
122-
design_matrix.design_matrix_df is not None
123-
and not design_matrix.design_matrix_df.empty
124-
):
125-
viewer = DesignMatrixPanel(
126-
design_matrix.design_matrix_df,
127-
design_matrix.xls_filename.name,
128-
)
129-
viewer.setMinimumHeight(500)
130-
viewer.setMinimumWidth(1000)
131-
viewer.adjustSize()
132-
viewer.exec_()
119+
viewer = DesignMatrixPanel(
120+
design_matrix.design_matrix_df,
121+
design_matrix.xls_filename.name,
122+
)
123+
viewer.setMinimumHeight(500)
124+
viewer.setMinimumWidth(1000)
125+
viewer.adjustSize()
126+
viewer.exec_()
133127

134128
@Slot(ExperimentConfigPanel)
135129
def experimentTypeChanged(self, w: ExperimentConfigPanel) -> None:

‎src/ert/run_models/ensemble_experiment.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66

77
import numpy as np
88

9-
from ert.enkf_main import sample_prior
9+
from ert.config import ConfigValidationError
10+
from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble
1011
from ert.ensemble_evaluator import EvaluatorServerConfig
1112
from ert.storage import Ensemble, Experiment, Storage
1213
from ert.trace import tracer
1314

1415
from ..run_arg import create_run_arguments
15-
from .base_run_model import BaseRunModel, StatusEvents
16+
from .base_run_model import BaseRunModel, ErtRunError, StatusEvents
1617

1718
if TYPE_CHECKING:
1819
from ert.config import ErtConfig, QueueConfig
@@ -64,10 +65,27 @@ def run_experiment(
6465
) -> None:
6566
self.log_at_startup()
6667
self.restart = restart
68+
# If design matrix is present, we try to merge design matrix parameters
69+
# to the experiment parameters and set new active realizations
70+
parameters_config = self.ert_config.ensemble_config.parameter_configuration
71+
design_matrix = self.ert_config.analysis_config.design_matrix
72+
design_matrix_group = None
73+
if design_matrix is not None:
74+
try:
75+
parameters_config, design_matrix_group = (
76+
design_matrix.merge_with_existing_parameters(parameters_config)
77+
)
78+
except ConfigValidationError as exc:
79+
raise ErtRunError(str(exc)) from exc
80+
6781
if not restart:
6882
self.experiment = self._storage.create_experiment(
6983
name=self.experiment_name,
70-
parameters=self.ert_config.ensemble_config.parameter_configuration,
84+
parameters=(
85+
[*parameters_config, design_matrix_group]
86+
if design_matrix_group is not None
87+
else parameters_config
88+
),
7189
observations=self.ert_config.observations,
7290
responses=self.ert_config.ensemble_config.response_configuration,
7391
)
@@ -90,12 +108,21 @@ def run_experiment(
90108
np.array(self.active_realizations, dtype=bool),
91109
ensemble=self.ensemble,
92110
)
111+
93112
sample_prior(
94113
self.ensemble,
95114
np.where(self.active_realizations)[0],
96115
random_seed=self.random_seed,
97116
)
98117

118+
if design_matrix_group is not None and design_matrix is not None:
119+
save_design_matrix_to_ensemble(
120+
design_matrix.design_matrix_df,
121+
self.ensemble,
122+
np.where(self.active_realizations)[0],
123+
design_matrix_group.name,
124+
)
125+
99126
self._evaluate_and_postprocess(
100127
run_args,
101128
self.ensemble,

‎src/ert/run_models/model_factory.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,20 @@ def _setup_ensemble_experiment(
117117
args: Namespace,
118118
status_queue: SimpleQueue[StatusEvents],
119119
) -> EnsembleExperiment:
120-
active_realizations = _realizations(args, config.model_config.num_realizations)
120+
active_realizations = _realizations(
121+
args, config.model_config.num_realizations
122+
).tolist()
123+
if (
124+
config.analysis_config.design_matrix is not None
125+
and config.analysis_config.design_matrix.active_realizations is not None
126+
):
127+
active_realizations = config.analysis_config.design_matrix.active_realizations
121128
experiment_name = args.experiment_name
122129
assert experiment_name is not None
123130

124131
return EnsembleExperiment(
125132
random_seed=config.random_seed,
126-
active_realizations=active_realizations.tolist(),
133+
active_realizations=active_realizations,
127134
ensemble_name=args.current_ensemble,
128135
minimum_required_realizations=config.analysis_config.minimum_required_realizations,
129136
experiment_name=experiment_name,
@@ -271,9 +278,9 @@ def _setup_iterative_ensemble_smoother(
271278
random_seed=config.random_seed,
272279
active_realizations=active_realizations.tolist(),
273280
target_ensemble=_iterative_ensemble_format(args),
274-
number_of_iterations=int(args.num_iterations)
275-
if args.num_iterations is not None
276-
else 4,
281+
number_of_iterations=(
282+
int(args.num_iterations) if args.num_iterations is not None else 4
283+
),
277284
minimum_required_realizations=config.analysis_config.minimum_required_realizations,
278285
num_retries_per_iter=4,
279286
experiment_name=experiment_name,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
import os
2+
import stat
3+
from textwrap import dedent
4+
5+
import numpy as np
6+
import pandas as pd
7+
import pytest
8+
9+
from ert.cli.main import ErtCliError
10+
from ert.config import ErtConfig
11+
from ert.mode_definitions import ENSEMBLE_EXPERIMENT_MODE
12+
from ert.storage import open_storage
13+
from tests.ert.ui_tests.cli.run_cli import run_cli
14+
15+
16+
@pytest.mark.usefixtures("copy_poly_case")
def test_run_poly_example_with_design_matrix():
    """Run an ensemble experiment whose coefficients all come from a design matrix.

    The design sheet provides "a" per realization, the default sheet provides
    "b" and "c". After the run, the values stored under DESIGN_MATRIX must
    match what was written to the spreadsheet.
    """
    num_realizations = 10
    a_values = list(range(num_realizations))
    design_sheet = pd.DataFrame(
        {
            "REAL": list(range(num_realizations)),
            "a": a_values,
        }
    )
    defaults_sheet = pd.DataFrame([["b", 1], ["c", 2]])
    with pd.ExcelWriter("poly_design.xlsx") as excel_writer:
        design_sheet.to_excel(excel_writer, index=False, sheet_name="DesignSheet01")
        defaults_sheet.to_excel(
            excel_writer, index=False, sheet_name="DefaultSheet", header=False
        )

    ert_config_text = dedent(
        """\
        QUEUE_OPTION LOCAL MAX_RUNNING 10
        RUNPATH poly_out/realization-<IENS>/iter-<ITER>
        NUM_REALIZATIONS 10
        MIN_REALIZATIONS 1
        GEN_DATA POLY_RES RESULT_FILE:poly.out
        DESIGN_MATRIX poly_design.xlsx DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultSheet
        INSTALL_JOB poly_eval POLY_EVAL
        FORWARD_MODEL poly_eval
        """
    )
    with open("poly.ert", "w", encoding="utf-8") as config_file:
        config_file.write(ert_config_text)

    # Forward model script: reads the coefficients from the DESIGN_MATRIX
    # entry of parameters.json and evaluates the polynomial.
    poly_eval_script = dedent(
        """\
        #!/usr/bin/env python
        import json

        def _load_coeffs(filename):
            with open(filename, encoding="utf-8") as f:
                return json.load(f)["DESIGN_MATRIX"]

        def _evaluate(coeffs, x):
            return coeffs["a"] * x**2 + coeffs["b"] * x + coeffs["c"]

        if __name__ == "__main__":
            coeffs = _load_coeffs("parameters.json")
            output = [_evaluate(coeffs, x) for x in range(10)]
            with open("poly.out", "w", encoding="utf-8") as f:
                f.write("\\n".join(map(str, output)))
        """
    )
    with open("poly_eval.py", "w", encoding="utf-8") as script_file:
        script_file.write(poly_eval_script)

    # Add execute permission for user, group and others.
    executable_bits = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
    os.chmod("poly_eval.py", os.stat("poly_eval.py").st_mode | executable_bits)

    run_cli(
        ENSEMBLE_EXPERIMENT_MODE,
        "--disable-monitor",
        "poly.ert",
        "--experiment-name",
        "test-experiment",
    )

    storage_path = ErtConfig.from_file("poly.ert").ens_path
    with open_storage(storage_path) as storage:
        experiment = storage.get_experiment_by_name("test-experiment")
        ensemble = experiment.get_ensemble_by_name("default")
        params = ensemble.load_parameters("DESIGN_MATRIX")["values"]
        np.testing.assert_array_equal(params[:, 0], a_values)
        np.testing.assert_array_equal(params[:, 1], 10 * [1])
        np.testing.assert_array_equal(params[:, 2], 10 * [2])
93+
94+
95+
@pytest.mark.usefixtures("copy_poly_case")
@pytest.mark.parametrize(
    "default_values, error_msg",
    [
        ([["b", 1], ["c", 2]], None),
        ([["b", 1]], "Overlapping parameter names found in design matrix!"),
    ],
)
def test_run_poly_example_with_design_matrix_and_genkw_merge(default_values, error_msg):
    """Run the poly example with both a GEN_KW group and a design matrix.

    When ``error_msg`` is set, the CLI run is expected to fail with that
    message. Otherwise the values stored under COEFFS must match what was
    written to the spreadsheet.
    """
    num_realizations = 10
    a_values = list(range(num_realizations))
    design_sheet = pd.DataFrame(
        {
            "REAL": list(range(num_realizations)),
            "a": a_values,
        }
    )
    defaults_sheet = pd.DataFrame(default_values)
    with pd.ExcelWriter("poly_design.xlsx") as excel_writer:
        design_sheet.to_excel(excel_writer, index=False, sheet_name="DesignSheet01")
        defaults_sheet.to_excel(
            excel_writer, index=False, sheet_name="DefaultSheet", header=False
        )

    ert_config_text = dedent(
        """\
        QUEUE_OPTION LOCAL MAX_RUNNING 10
        RUNPATH poly_out/realization-<IENS>/iter-<ITER>
        NUM_REALIZATIONS 10
        MIN_REALIZATIONS 1
        GEN_DATA POLY_RES RESULT_FILE:poly.out
        GEN_KW COEFFS coeff_priors
        DESIGN_MATRIX poly_design.xlsx DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultSheet
        INSTALL_JOB poly_eval POLY_EVAL
        FORWARD_MODEL poly_eval
        """
    )
    with open("poly.ert", "w", encoding="utf-8") as config_file:
        config_file.write(ert_config_text)

    # Forward model script: reads the coefficients from the COEFFS entry of
    # parameters.json and evaluates the polynomial.
    poly_eval_script = dedent(
        """\
        #!/usr/bin/env python
        import json

        def _load_coeffs(filename):
            with open(filename, encoding="utf-8") as f:
                return json.load(f)["COEFFS"]

        def _evaluate(coeffs, x):
            return coeffs["a"] * x**2 + coeffs["b"] * x + coeffs["c"]

        if __name__ == "__main__":
            coeffs = _load_coeffs("parameters.json")
            output = [_evaluate(coeffs, x) for x in range(10)]
            with open("poly.out", "w", encoding="utf-8") as f:
                f.write("\\n".join(map(str, output)))
        """
    )
    with open("poly_eval.py", "w", encoding="utf-8") as script_file:
        script_file.write(poly_eval_script)

    # Add execute permission for user, group and others.
    executable_bits = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
    os.chmod("poly_eval.py", os.stat("poly_eval.py").st_mode | executable_bits)

    cli_args = (
        ENSEMBLE_EXPERIMENT_MODE,
        "--disable-monitor",
        "poly.ert",
        "--experiment-name",
        "test-experiment",
    )

    if error_msg:
        with pytest.raises(ErtCliError, match=error_msg):
            run_cli(*cli_args)
        return

    run_cli(*cli_args)
    storage_path = ErtConfig.from_file("poly.ert").ens_path
    with open_storage(storage_path) as storage:
        experiment = storage.get_experiment_by_name("test-experiment")
        ensemble = experiment.get_ensemble_by_name("default")
        params = ensemble.load_parameters("COEFFS")["values"]
        np.testing.assert_array_equal(params[:, 0], a_values)
        np.testing.assert_array_equal(params[:, 1], 10 * [1])
        np.testing.assert_array_equal(params[:, 2], 10 * [2])

‎tests/ert/unit_tests/config/test_analysis_config.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from textwrap import dedent
22

33
import hypothesis.strategies as st
4+
import pandas as pd
45
import pytest
56
from hypothesis import given
67

@@ -15,22 +16,33 @@
1516

1617

1718
def test_analysis_config_from_file_is_same_as_from_dict(monkeypatch, tmp_path):
18-
with open(tmp_path / "my_design_matrix.xlsx", "w", encoding="utf-8"):
19-
pass
19+
with pd.ExcelWriter(tmp_path / "my_design_matrix.xlsx") as xl_write:
20+
design_matrix_df = pd.DataFrame(
21+
{
22+
"REAL": [0, 1, 2],
23+
"a": [1, 2, 3],
24+
"b": [0, 2, 0],
25+
}
26+
)
27+
default_sheet_df = pd.DataFrame([["a", 1], ["b", 4]])
28+
design_matrix_df.to_excel(xl_write, index=False, sheet_name="my_sheet")
29+
default_sheet_df.to_excel(
30+
xl_write, index=False, sheet_name="my_default_sheet", header=False
31+
)
2032
monkeypatch.chdir(tmp_path)
2133
assert ErtConfig.from_file_contents(
2234
dedent(
2335
"""
24-
NUM_REALIZATIONS 10
25-
MIN_REALIZATIONS 10
36+
NUM_REALIZATIONS 3
37+
MIN_REALIZATIONS 3
2638
ANALYSIS_SET_VAR STD_ENKF ENKF_TRUNCATION 0.8
2739
DESIGN_MATRIX my_design_matrix.xlsx DESIGN_SHEET:my_sheet DEFAULT_SHEET:my_default_sheet
2840
"""
2941
)
3042
).analysis_config == AnalysisConfig.from_dict(
3143
{
32-
ConfigKeys.NUM_REALIZATIONS: 10,
33-
ConfigKeys.MIN_REALIZATIONS: "10",
44+
ConfigKeys.NUM_REALIZATIONS: 3,
45+
ConfigKeys.MIN_REALIZATIONS: "3",
3446
ConfigKeys.ANALYSIS_SET_VAR: [
3547
("STD_ENKF", "ENKF_TRUNCATION", 0.8),
3648
],

‎tests/ert/unit_tests/gui/simulation/test_run_dialog.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from queue import SimpleQueue
33
from unittest.mock import MagicMock, Mock, patch
44

5+
import pandas as pd
56
import pytest
67
from pytestqt.qtbot import QtBot
78
from qtpy import QtWidgets
@@ -732,15 +733,26 @@ def test_that_stdout_and_stderr_buttons_react_to_file_content(
732733
def test_that_design_matrix_show_parameters_button_is_visible(
733734
design_matrix_entry, qtbot: QtBot, storage
734735
):
735-
xls_filename = "design_matrix.xls"
736-
with open(f"{xls_filename}", "w", encoding="utf-8"):
737-
pass
736+
xls_filename = "design_matrix.xlsx"
737+
design_matrix_df = pd.DataFrame(
738+
{
739+
"REAL": list(range(3)),
740+
"a": [0, 1, 2],
741+
}
742+
)
743+
default_sheet_df = pd.DataFrame([["b", 1], ["c", 2]])
744+
with pd.ExcelWriter(xls_filename) as xl_write:
745+
design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01")
746+
default_sheet_df.to_excel(
747+
xl_write, index=False, sheet_name="DefaultSheet", header=False
748+
)
749+
738750
config_file = "minimal_config.ert"
739751
with open(config_file, "w", encoding="utf-8") as f:
740752
f.write("NUM_REALIZATIONS 1")
741753
if design_matrix_entry:
742754
f.write(
743-
f"\nDESIGN_MATRIX {xls_filename} DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultValues"
755+
f"\nDESIGN_MATRIX {xls_filename} DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultSheet"
744756
)
745757

746758
args_mock = Mock()

‎tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py

+85-11
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,83 @@
33
import pytest
44

55
from ert.config.design_matrix import DESIGN_MATRIX_GROUP, DesignMatrix
6+
from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition
7+
8+
9+
@pytest.mark.parametrize(
    "parameters, error_msg",
    [
        pytest.param(
            {"COEFFS": ["a", "b"]},
            "",
            id="genkw_replaced",
        ),
        pytest.param(
            {"COEFFS": ["a"]},
            "Overlapping parameter names found in design matrix!",
            id="ValidationErrorOverlapping",
        ),
        pytest.param(
            {"COEFFS": ["aa", "bb"], "COEFFS2": ["cc", "dd"]},
            "",
            id="DESIGN_MATRIX_GROUP",
        ),
        pytest.param(
            {"COEFFS": ["a", "b"], "COEFFS2": ["a", "b"]},
            "Multiple overlapping groups with design matrix found in existing parameters!",
            id="ValidationErrorMultipleGroups",
        ),
    ],
)
def test_read_and_merge_with_existing_parameters(tmp_path, parameters, error_msg):
    """Merge a design matrix holding "a" and "b" with the given GEN_KW groups.

    ``parameters`` maps GEN_KW group names to their parameter names. A
    non-empty ``error_msg`` means the merge is expected to raise with that
    message.
    """
    extra_genkw_config = [
        GenKwConfig(
            name=group_name,
            forward_init=False,
            template_file="",
            transform_function_definitions=[
                TransformFunctionDefinition(param, "UNIFORM", [0, 1])
                for param in param_names
            ],
            output_file="kw.txt",
            update=True,
        )
        for group_name, param_names in parameters.items()
    ]

    design_path = tmp_path / "design_matrix.xlsx"
    design_sheet = pd.DataFrame(
        {
            "REAL": [0, 1, 2],
            "a": [1, 2, 3],
            "b": [0, 2, 0],
        }
    )
    defaults_sheet = pd.DataFrame([["a", 1], ["b", 4]])
    with pd.ExcelWriter(design_path) as excel_writer:
        design_sheet.to_excel(excel_writer, index=False, sheet_name="DesignSheet01")
        defaults_sheet.to_excel(
            excel_writer, index=False, sheet_name="DefaultValues", header=False
        )
    design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues")

    if error_msg:
        with pytest.raises(ValueError, match=error_msg):
            design_matrix.merge_with_existing_parameters(extra_genkw_config)
        return

    # Number of GEN_KW groups -> (groups kept after merge, design group name).
    expectations = {
        1: (0, "COEFFS"),
        2: (2, DESIGN_MATRIX_GROUP),
    }
    if len(parameters) in expectations:
        expected_kept, expected_group_name = expectations[len(parameters)]
        new_config_parameters, design_group = (
            design_matrix.merge_with_existing_parameters(extra_genkw_config)
        )
        assert len(new_config_parameters) == expected_kept
        assert design_group.name == expected_group_name
683

784

885
def test_reading_design_matrix(tmp_path):
@@ -23,10 +100,8 @@ def test_reading_design_matrix(tmp_path):
23100
xl_write, index=False, sheet_name="DefaultValues", header=False
24101
)
25102
design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
26-
design_matrix.read_design_matrix()
27103
design_params = design_matrix.parameter_configuration.get(DESIGN_MATRIX_GROUP, [])
28104
assert all(param in design_params for param in ("a", "b", "c", "one", "d"))
29-
assert design_matrix.num_realizations == 3
30105
assert design_matrix.active_realizations == [True, True, False, False, True]
31106

32107

@@ -62,9 +137,9 @@ def test_reading_design_matrix_validate_reals(tmp_path, real_column, error_msg):
62137
default_sheet_df.to_excel(
63138
xl_write, index=False, sheet_name="DefaultValues", header=False
64139
)
65-
design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
140+
66141
with pytest.raises(ValueError, match=error_msg):
67-
design_matrix.read_design_matrix()
142+
DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
68143

69144

70145
@pytest.mark.parametrize(
@@ -98,9 +173,9 @@ def test_reading_design_matrix_validate_headers(tmp_path, column_names, error_ms
98173
default_sheet_df.to_excel(
99174
xl_write, index=False, sheet_name="DefaultValues", header=False
100175
)
101-
design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
176+
102177
with pytest.raises(ValueError, match=error_msg):
103-
design_matrix.read_design_matrix()
178+
DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
104179

105180

106181
@pytest.mark.parametrize(
@@ -134,9 +209,9 @@ def test_reading_design_matrix_validate_cells(tmp_path, values, error_msg):
134209
default_sheet_df.to_excel(
135210
xl_write, index=False, sheet_name="DefaultValues", header=False
136211
)
137-
design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
212+
138213
with pytest.raises(ValueError, match=error_msg):
139-
design_matrix.read_design_matrix()
214+
DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
140215

141216

142217
@pytest.mark.parametrize(
@@ -180,9 +255,9 @@ def test_reading_default_sheet_validation(tmp_path, data, error_msg):
180255
default_sheet_df.to_excel(
181256
xl_write, index=False, sheet_name="DefaultValues", header=False
182257
)
183-
design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
258+
184259
with pytest.raises(ValueError, match=error_msg):
185-
design_matrix.read_design_matrix()
260+
DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
186261

187262

188263
def test_default_values_used(tmp_path):
@@ -202,7 +277,6 @@ def test_default_values_used(tmp_path):
202277
xl_write, index=False, sheet_name="DefaultValues", header=False
203278
)
204279
design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
205-
design_matrix.read_design_matrix()
206280
df = design_matrix.design_matrix_df
207281
np.testing.assert_equal(df[DESIGN_MATRIX_GROUP, "one"], np.array([1, 1, 1, 1]))
208282
np.testing.assert_equal(df[DESIGN_MATRIX_GROUP, "b"], np.array([0, 2, 0, 1]))

‎tests/ert/unit_tests/test_libres_facade.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
from datetime import datetime, timedelta
33
from textwrap import dedent
44

5+
import numpy as np
56
import pytest
7+
from pandas import ExcelWriter
68
from pandas.core.frame import DataFrame
79
from resdata.summary import Summary
810

911
from ert.config import ErtConfig
10-
from ert.enkf_main import sample_prior
12+
from ert.config.design_matrix import DESIGN_MATRIX_GROUP, DesignMatrix
13+
from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble
1114
from ert.libres_facade import LibresFacade
1215
from ert.storage import open_storage
1316

@@ -253,3 +256,52 @@ def test_load_gen_kw_not_sorted(storage, tmpdir, snapshot):
253256

254257
data = ensemble.load_all_gen_kw_data()
255258
snapshot.assert_match(data.round(12).to_csv(), "gen_kw_unsorted")
259+
260+
261+
@pytest.mark.parametrize(
    "reals, expect_error",
    [
        pytest.param(
            list(range(10)),
            False,
            id="correct_active_realizations",
        ),
        pytest.param([10, 11], True, id="incorrect_active_realizations"),
    ],
)
def test_save_parameters_to_storage_from_design_dataframe(
    tmp_path, reals, expect_error
):
    """Save design matrix rows to an ensemble and read them back.

    ``reals`` are the realization numbers passed to
    ``save_design_matrix_to_ensemble``. With ``expect_error`` set, the call
    must raise KeyError; otherwise the stored names and values must match the
    design matrix.
    """
    design_path = tmp_path / "design_matrix.xlsx"
    ensemble_size = 10
    a_values = np.random.default_rng().uniform(-5, 5, 10)
    b_values = np.random.default_rng().uniform(-5, 5, 10)
    c_values = np.random.default_rng().uniform(-5, 5, 10)
    design_matrix_df = DataFrame({"a": a_values, "b": b_values, "c": c_values})
    with ExcelWriter(design_path) as xl_write:
        design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01")
        DataFrame().to_excel(
            xl_write, index=False, sheet_name="DefaultValues", header=False
        )
    design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
    with open_storage(tmp_path / "storage", mode="w") as storage:
        experiment_id = storage.create_experiment(
            parameters=[design_matrix.parameter_configuration[DESIGN_MATRIX_GROUP]]
        )
        ensemble = storage.create_ensemble(
            experiment_id, name="default", ensemble_size=ensemble_size
        )
        if expect_error:
            with pytest.raises(KeyError):
                save_design_matrix_to_ensemble(
                    design_matrix.design_matrix_df, ensemble, reals
                )
        else:
            save_design_matrix_to_ensemble(
                design_matrix.design_matrix_df, ensemble, reals
            )
            params = ensemble.load_parameters(DESIGN_MATRIX_GROUP)["values"]
            # Previously this comparison was evaluated without ``assert``,
            # so a wrong set of names could never fail the test.
            assert list(params.names.values) == ["a", "b", "c"]
            np.testing.assert_array_almost_equal(params[:, 0], a_values)
            np.testing.assert_array_almost_equal(params[:, 1], b_values)
            np.testing.assert_array_almost_equal(params[:, 2], c_values)

0 commit comments

Comments
 (0)
Please sign in to comment.