Skip to content

Commit ddcf1e0

Browse files
committed
Sync with main
1 parent 761d134 commit ddcf1e0

File tree

4 files changed

+38
-35
lines changed

4 files changed

+38
-35
lines changed

src/ert/config/ert_config.py

+8-15
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
ForwardModelStepPlugin,
3737
ForwardModelStepValidationError,
3838
)
39-
from .gen_kw_config import GenKwConfig
4039
from .model_config import ModelConfig
4140
from .observation_vector import ObsVector
4241
from .observations import EnkfObs
@@ -57,19 +56,16 @@
5756
parse_contents,
5857
read_file,
5958
)
60-
from .parsing import (
61-
parse as parse_config,
62-
)
59+
from .parsing import parse as parse_config
6360
from .parsing.observations_parser import (
6461
GenObsValues,
6562
HistoryValues,
6663
ObservationConfigError,
6764
SummaryValues,
6865
)
69-
from .parsing.observations_parser import (
70-
parse as parse_observations,
71-
)
66+
from .parsing.observations_parser import parse as parse_observations
7267
from .queue_config import QueueConfig
68+
from .scalar_parameter import ScalarParameters
7369
from .workflow import Workflow
7470
from .workflow_job import ErtScriptLoadFailure, WorkflowJob
7571

@@ -849,17 +845,14 @@ def from_dict(cls, config_dict) -> Self:
849845
raise ConfigValidationError.from_collected(errors)
850846

851847
if dm := analysis_config.design_matrix:
852-
dm_params = [
853-
x.name
854-
for x in dm.parameter_configuration.transform_function_definitions
855-
]
848+
dm_params = [x.param_name for x in dm.scalars]
856849
for group_name, config in ensemble_config.parameter_configs.items():
857850
overlapping = []
858-
if not isinstance(config, GenKwConfig):
851+
if not isinstance(config, ScalarParameters):
859852
continue
860-
for transform_definition in config.transform_function_definitions:
861-
if transform_definition.name in dm_params:
862-
overlapping.append(transform_definition.name)
853+
for param in config.scalars:
854+
if param.param_name in dm_params:
855+
overlapping.append(param.param_name)
863856
if overlapping:
864857
ConfigWarning.warn(
865858
f"Parameters {overlapping} from GEN_KW group '{group_name}' "

src/ert/config/scalar_parameter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def sample_or_load(
367367
)
368368
params["realization"] = real
369369
if design_matrix_df is not None:
370-
row = design_matrix_df.loc[real]["DESIGN_MATRIX"]
370+
row = design_matrix_df.loc[real]
371371
for parameter in self.scalars:
372372
if parameter.input_source == DataSource.DESIGN_MATRIX:
373373
value = row[parameter.param_name]

src/ert/run_models/ensemble_experiment.py

-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def run_experiment(
102102
else:
103103
parameters_config.append(param)
104104
self._parameter_configuration = parameters_config
105-
parameters_config = self._parameter_configuration
106105
if not restart:
107106
self.run_workflows(
108107
HookRuntime.PRE_EXPERIMENT,

src/ert/run_models/multiple_data_assimilation.py

+29-18
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,16 @@
88

99
import numpy as np
1010

11-
from ert.config import ErtConfig, ESSettings, HookRuntime, UpdateSettings
11+
from ert.config import (
12+
ErtConfig,
13+
ESSettings,
14+
HookRuntime,
15+
ParameterConfig,
16+
ScalarParameters,
17+
UpdateSettings,
18+
)
1219
from ert.config.parsing.config_errors import ConfigValidationError
13-
from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble
20+
from ert.enkf_main import sample_prior
1421
from ert.ensemble_evaluator import EvaluatorServerConfig
1522
from ert.storage import Ensemble, Storage
1623
from ert.trace import tracer
@@ -103,14 +110,21 @@ def run_experiment(
103110

104111
parameters_config = self._parameter_configuration
105112
design_matrix = self._design_matrix
106-
design_matrix_group = None
113+
parameters_config: list[ParameterConfig] = []
114+
design_matrix = self._design_matrix
107115
if design_matrix is not None:
108-
try:
109-
parameters_config, design_matrix_group = (
110-
design_matrix.merge_with_existing_parameters(parameters_config)
111-
)
112-
except ConfigValidationError as exc:
113-
raise ErtRunError(str(exc)) from exc
116+
for param in self._parameter_configuration:
117+
if isinstance(param, ScalarParameters):
118+
try:
119+
new_scalar_config = (
120+
design_matrix.merge_with_existing_parameters(param)
121+
)
122+
parameters_config.append(new_scalar_config)
123+
except ConfigValidationError as exc:
124+
raise ErtRunError(str(exc)) from exc
125+
else:
126+
parameters_config.append(param)
127+
self._parameter_configuration = parameters_config
114128

115129
self.restart = restart
116130
if self.restart_run:
@@ -138,8 +152,7 @@ def run_experiment(
138152
)
139153
sim_args = {"weights": self._relative_weights}
140154
experiment = self._storage.create_experiment(
141-
parameters=parameters_config
142-
+ ([design_matrix_group] if design_matrix_group else []),
155+
parameters=parameters_config,
143156
observations=self._observations,
144157
responses=self._response_configuration,
145158
simulation_arguments=sim_args,
@@ -164,15 +177,13 @@ def run_experiment(
164177
prior,
165178
np.where(self.active_realizations)[0],
166179
random_seed=self.random_seed,
180+
design_matrix_df=(
181+
design_matrix.design_matrix_df
182+
if design_matrix is not None
183+
else None
184+
),
167185
)
168186

169-
if design_matrix_group is not None and design_matrix is not None:
170-
save_design_matrix_to_ensemble(
171-
design_matrix.design_matrix_df,
172-
prior,
173-
np.where(self.active_realizations)[0],
174-
design_matrix_group.name,
175-
)
176187
self._evaluate_and_postprocess(
177188
prior_args,
178189
prior,

0 commit comments

Comments
 (0)