8
8
import pandas as pd
9
9
from pandas .api .types import is_integer_dtype
10
10
11
- from ert .config .gen_kw_config import GenKwConfig , TransformFunctionDefinition
12
-
13
11
from ._option_dict import option_dict
14
12
from .parsing import ConfigValidationError , ErrorInfo
13
+ from .scalar_parameter import (
14
+ DataSource ,
15
+ ScalarParameter ,
16
+ ScalarParameters ,
17
+ TransRawSettings ,
18
+ )
15
19
16
20
if TYPE_CHECKING :
17
- from ert . config import ParameterConfig
21
+ pass
18
22
19
23
DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"
20
24
@@ -32,7 +36,7 @@ def __post_init__(self) -> None:
32
36
(
33
37
self .active_realizations ,
34
38
self .design_matrix_df ,
35
- self .parameter_configuration ,
39
+ self .scalars ,
36
40
) = self .read_and_validate_design_matrix ()
37
41
except (ValueError , AttributeError ) as exc :
38
42
raise ConfigValidationError .with_context (
@@ -102,66 +106,54 @@ def merge_with_other(self, dm_other: DesignMatrix) -> None:
102
106
except ValueError as exc :
103
107
errors .append (ErrorInfo (f"Error when merging design matrices { exc } !" ))
104
108
105
- for tfd in dm_other .parameter_configuration . transform_function_definitions :
106
- self .parameter_configuration . transform_function_definitions . append (tfd )
109
+ for param in dm_other .scalars :
110
+ self .scalars . append (param )
107
111
108
112
if errors :
109
113
raise ConfigValidationError .from_collected (errors )
110
114
111
115
def merge_with_existing_parameters (
112
- self , existing_parameters : list [ ParameterConfig ]
113
- ) -> tuple [ list [ ParameterConfig ], GenKwConfig ] :
116
+ self , existing_scalars : ScalarParameters
117
+ ) -> ScalarParameters :
114
118
"""
115
119
This method merges the design matrix parameters with the existing parameters and
116
- returns the new list of existing parameters, wherein we drop GEN_KW group having a full overlap with the design matrix group.
117
- GEN_KW group that was dropped will acquire a new name from the design matrix group.
118
- Additionally, the ParameterConfig which is the design matrix group is returned separately.
119
-
120
+ returns the new list of existing parameters.
120
121
Args:
121
- existing_parameters (List[ParameterConfig] ): List of existing parameters
122
+ existing_scalars (ScalarParameters ): existing scalar parameters
122
123
123
- Raises:
124
- ConfigValidationError: If there is a partial overlap between the design matrix group and any existing GEN_KW group
125
124
126
125
Returns:
127
- tuple[List[ParameterConfig], ParameterConfig]: List of existing parameters and the dedicated design matrix group
126
+ ScalarParameters: new set of ScalarParameters
128
127
"""
129
128
130
- new_param_config : list [ParameterConfig ] = []
131
-
132
- design_parameter_group = self .parameter_configuration
133
- design_keys = [e .name for e in design_parameter_group .transform_functions ]
129
+ all_params : list [ScalarParameter ] = []
134
130
135
- design_group_added = False
136
- for parameter_group in existing_parameters :
137
- if not isinstance (parameter_group , GenKwConfig ):
138
- new_param_config += [parameter_group ]
131
+ overlap_set = set ()
132
+ for existing_parameter in existing_scalars .scalars :
133
+ if existing_parameter .input_source == DataSource .DESIGN_MATRIX :
139
134
continue
140
- existing_keys = [e .name for e in parameter_group .transform_functions ]
141
- if set (existing_keys ) == set (design_keys ):
142
- if design_group_added :
143
- raise ConfigValidationError (
144
- "Multiple overlapping groups with design matrix found in existing parameters!\n "
145
- f"{ design_parameter_group .name } and { parameter_group .name } "
146
- )
147
-
148
- design_parameter_group .name = parameter_group .name
149
- design_parameter_group .template_file = parameter_group .template_file
150
- design_parameter_group .output_file = parameter_group .output_file
151
- design_group_added = True
152
- elif set (design_keys ) & set (existing_keys ):
153
- raise ConfigValidationError (
154
- "Overlapping parameter names found in design matrix!\n "
155
- f"{ DESIGN_MATRIX_GROUP } :{ design_keys } \n { parameter_group .name } :{ existing_keys } "
156
- "\n They need to match exactly or not at all."
157
- )
158
- else :
159
- new_param_config += [parameter_group ]
160
- return new_param_config , design_parameter_group
135
+ overlap = False
136
+ for parameter_design in self .scalars :
137
+ if existing_parameter .param_name == parameter_design .param_name :
138
+ parameter_design .group_name = existing_parameter .group_name
139
+ parameter_design .template_file = existing_parameter .template_file
140
+ parameter_design .output_file = existing_parameter .output_file
141
+ all_params .append (parameter_design )
142
+ overlap = True
143
+ overlap_set .add (existing_parameter .param_name )
144
+ break
145
+ if not overlap :
146
+ all_params .append (existing_parameter )
147
+
148
+ for parameter_design in self .scalars :
149
+ if parameter_design .param_name not in overlap_set :
150
+ all_params .append (parameter_design )
151
+
152
+ return ScalarParameters (scalars = all_params )
161
153
162
154
def read_and_validate_design_matrix (
163
155
self ,
164
- ) -> tuple [list [bool ], pd .DataFrame , GenKwConfig ]:
156
+ ) -> tuple [list [bool ], pd .DataFrame , list [ ScalarParameter ] ]:
165
157
# Read the parameter names (first row) as strings to prevent pandas from modifying them.
166
158
# This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet.
167
159
# By doing this, we can properly validate variable names, including detecting duplicates or missing names.
@@ -207,29 +199,25 @@ def read_and_validate_design_matrix(
207
199
208
200
design_matrix_df = pd .concat ([design_matrix_df , default_df ], axis = 1 )
209
201
210
- transform_function_definitions : list [TransformFunctionDefinition ] = []
202
+ scalars : list [ScalarParameter ] = []
211
203
for parameter in design_matrix_df .columns :
212
- transform_function_definitions .append (
213
- TransformFunctionDefinition (
214
- name = parameter ,
215
- param_name = "RAW" ,
216
- values = [],
204
+ scalars .append (
205
+ ScalarParameter (
206
+ param_name = parameter ,
207
+ group_name = DESIGN_MATRIX_GROUP ,
208
+ input_source = DataSource .DESIGN_MATRIX ,
209
+ distribution = TransRawSettings (),
210
+ template_file = None ,
211
+ output_file = None ,
212
+ update = False ,
217
213
)
218
214
)
219
- parameter_configuration = GenKwConfig (
220
- name = DESIGN_MATRIX_GROUP ,
221
- forward_init = False ,
222
- template_file = None ,
223
- output_file = None ,
224
- transform_function_definitions = transform_function_definitions ,
225
- update = False ,
226
- )
227
215
228
216
reals = design_matrix_df .index .tolist ()
229
217
return (
230
218
[x in reals for x in range (max (reals ) + 1 )],
231
219
design_matrix_df ,
232
- parameter_configuration ,
220
+ scalars ,
233
221
)
234
222
235
223
@staticmethod
0 commit comments