Skip to content

Commit b0d5e19

Browse files
committed
Add support for design matrix in ensemble smoother
1 parent ca6b7e1 commit b0d5e19

File tree

5 files changed

+60
-14
lines changed

5 files changed

+60
-14
lines changed

src/ert/gui/simulation/ensemble_smoother_panel.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
TargetEnsembleModel,
1616
TextModel,
1717
)
18+
from ert.gui.tools.design_matrix.design_matrix_panel import DesignMatrixPanel
1819
from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE
1920
from ert.run_models import EnsembleSmoother
20-
from ert.validation import ProperNameFormatArgument, RangeStringArgument
21-
from ert.validation.proper_name_argument import ExperimentValidation
21+
from ert.validation import (
22+
ExperimentValidation,
23+
ProperNameFormatArgument,
24+
RangeStringArgument,
25+
)
2226

2327
from .experiment_config_panel import ExperimentConfigPanel
2428

@@ -44,8 +48,8 @@ def __init__(
4448
) -> None:
4549
super().__init__(EnsembleSmoother)
4650
self.notifier = notifier
47-
layout = QFormLayout()
4851

52+
layout = QFormLayout()
4953
self.setObjectName("ensemble_smoother_panel")
5054

5155
self._experiment_name_field = StringBox(
@@ -63,6 +67,7 @@ def __init__(
6367

6468
runpath_label = CopyableLabel(text=run_path)
6569
layout.addRow("Runpath:", runpath_label)
70+
6671
number_of_realizations_label = QLabel(f"<b>{ensemble_size}</b>")
6772
layout.addRow(QLabel("Number of realizations:"), number_of_realizations_label)
6873

@@ -89,6 +94,15 @@ def __init__(
8994
self._active_realizations_field.setValidator(RangeStringArgument(ensemble_size))
9095
layout.addRow("Active realizations", self._active_realizations_field)
9196

97+
design_matrix = analysis_config.design_matrix
98+
if design_matrix is not None:
99+
layout.addRow(
100+
"Design Matrix",
101+
DesignMatrixPanel.get_design_matrix_button(
102+
self._active_realizations_field, design_matrix
103+
),
104+
)
105+
92106
self.setLayout(layout)
93107

94108
self._experiment_name_field.getValidationSupport().validationChanged.connect(

src/ert/run_models/ensemble_smoother.py

+25-3
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
import numpy as np
99

1010
from ert.config import ErtConfig, ESSettings, HookRuntime, UpdateSettings
11-
from ert.enkf_main import sample_prior
11+
from ert.config.parsing.config_errors import ConfigValidationError
12+
from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble
1213
from ert.ensemble_evaluator import EvaluatorServerConfig
1314
from ert.storage import Storage
1415
from ert.trace import tracer
1516

1617
from ..run_arg import create_run_arguments
17-
from .base_run_model import StatusEvents, UpdateRunModel
18+
from .base_run_model import ErtRunError, StatusEvents, UpdateRunModel
1819

1920
if TYPE_CHECKING:
2021
from ert.config import QueueConfig
@@ -66,6 +67,7 @@ def __init__(
6667
self.support_restart = False
6768

6869
self._parameter_configuration = config.ensemble_config.parameter_configuration
70+
self._design_matrix = config.analysis_config.design_matrix
6971
self._observations = config.observations
7072
self._response_configuration = config.ensemble_config.response_configuration
7173

@@ -74,14 +76,27 @@ def run_experiment(
7476
self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False
7577
) -> None:
7678
self.log_at_startup()
79+
80+
parameters_config = self._parameter_configuration
81+
design_matrix = self._design_matrix
82+
design_matrix_group = None
83+
if design_matrix is not None:
84+
try:
85+
parameters_config, design_matrix_group = (
86+
design_matrix.merge_with_existing_parameters(parameters_config)
87+
)
88+
except ConfigValidationError as exc:
89+
raise ErtRunError(str(exc)) from exc
90+
7791
self.restart = restart
7892
self.run_workflows(
7993
HookRuntime.PRE_EXPERIMENT,
8094
fixtures={"random_seed": self.random_seed},
8195
)
8296
ensemble_format = self.target_ensemble_format
8397
experiment = self._storage.create_experiment(
84-
parameters=self._parameter_configuration,
98+
parameters=parameters_config
99+
+ ([design_matrix_group] if design_matrix_group else []),
85100
observations=self._observations,
86101
responses=self._response_configuration,
87102
name=self.experiment_name,
@@ -106,6 +121,13 @@ def run_experiment(
106121
random_seed=self.random_seed,
107122
)
108123

124+
if design_matrix_group is not None and design_matrix is not None:
125+
save_design_matrix_to_ensemble(
126+
design_matrix.design_matrix_df,
127+
prior,
128+
np.where(self.active_realizations)[0],
129+
design_matrix_group.name,
130+
)
109131
self._evaluate_and_postprocess(
110132
prior_args,
111133
prior,

src/ert/run_models/model_factory.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,7 @@ def _setup_ensemble_smoother(
187187
update_settings: UpdateSettings,
188188
status_queue: SimpleQueue[StatusEvents],
189189
) -> EnsembleSmoother:
190-
active_realizations = _realizations(
191-
args, config.model_config.num_realizations
192-
).tolist()
190+
active_realizations = _get_active_realizations_list(args, config)
193191
if len(active_realizations) < 2:
194192
raise ConfigValidationError(
195193
"Number of active realizations must be at least 2 for an update step"

tests/ert/ui_tests/cli/analysis/test_design_matrix.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111

1212
from ert.cli.main import ErtCliError
1313
from ert.config import ErtConfig
14-
from ert.mode_definitions import ENSEMBLE_EXPERIMENT_MODE, ES_MDA_MODE
14+
from ert.mode_definitions import (
15+
ENSEMBLE_EXPERIMENT_MODE,
16+
ENSEMBLE_SMOOTHER_MODE,
17+
ES_MDA_MODE,
18+
)
1519
from ert.storage import open_storage
1620
from tests.ert.ui_tests.cli.run_cli import run_cli
1721

@@ -304,7 +308,14 @@ def _evaluate(coeffs, x):
304308

305309

306310
@pytest.mark.usefixtures("copy_poly_case")
307-
def test_design_matrix_on_esmda():
311+
@pytest.mark.parametrize(
312+
"experiment_mode, ensemble_name, iterations",
313+
[
314+
(ES_MDA_MODE, "default_", 4),
315+
(ENSEMBLE_SMOOTHER_MODE, "iter-", 2),
316+
],
317+
)
318+
def test_design_matrix_on_esmda(experiment_mode, ensemble_name, iterations):
308319
design_path = "design_matrix.xlsx"
309320
reals = range(10)
310321
values = [random.uniform(0, 2) for _ in reals]
@@ -375,7 +386,7 @@ def _evaluate(coeffs, x):
375386
f.write("c UNIFORM 0 5")
376387

377388
run_cli(
378-
ES_MDA_MODE,
389+
experiment_mode,
379390
"--disable-monitoring",
380391
"poly.ert",
381392
"--experiment-name",
@@ -385,8 +396,8 @@ def _evaluate(coeffs, x):
385396
coeffs_a_previous = None
386397
with open_storage(storage_path) as storage:
387398
experiment = storage.get_experiment_by_name("test-experiment")
388-
for i in range(4):
389-
ensemble = experiment.get_ensemble_by_name(f"default_{i}")
399+
for i in range(iterations):
400+
ensemble = experiment.get_ensemble_by_name(f"{ensemble_name}{i}")
390401

391402
# coeffs_a should be different in all realizations
392403
coeffs_a = ensemble.load_parameters("COEFFS_A")["values"].values.flatten()

tests/ert/unit_tests/cli/test_model_hook_order.py

+1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def test_hook_call_order_ensemble_smoother(monkeypatch):
134134
)
135135
test_class.run_ensemble_evaluator = MagicMock(return_value=[0])
136136
test_class._storage = storage_mock
137+
test_class._design_matrix = None
137138
test_class.run_experiment(MagicMock())
138139

139140
assert run_wfs_mock.mock_calls == EXPECTED_CALL_ORDER

0 commit comments

Comments
 (0)