8
8
import pandas as pd
9
9
from pandas .api .types import is_integer_dtype
10
10
11
- from ert .config .gen_kw_config import GenKwConfig , TransformFunctionDefinition
12
-
13
11
from ._option_dict import option_dict
14
12
from .parsing import ConfigValidationError , ErrorInfo
13
+ from .scalar_parameter import (
14
+ SCALAR_PARAMETERS_NAME ,
15
+ DataSource ,
16
+ ScalarParameter ,
17
+ ScalarParameters ,
18
+ TransRawSettings ,
19
+ )
15
20
16
21
if TYPE_CHECKING :
17
- from ert . config import ParameterConfig
22
+ pass
18
23
19
24
DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"
20
25
@@ -32,7 +37,7 @@ def __post_init__(self) -> None:
32
37
(
33
38
self .active_realizations ,
34
39
self .design_matrix_df ,
35
- self .parameter_configuration ,
40
+ self .scalars ,
36
41
) = self .read_design_matrix ()
37
42
except (ValueError , AttributeError ) as exc :
38
43
raise ConfigValidationError .with_context (
@@ -102,64 +107,57 @@ def merge_with_other(self, dm_other: DesignMatrix) -> None:
102
107
except ValueError as exc :
103
108
errors .append (ErrorInfo (f"Error when merging design matrices { exc } !" ))
104
109
105
- for tfd in dm_other .parameter_configuration . transform_function_definitions :
106
- self .parameter_configuration . transform_function_definitions . append (tfd )
110
+ for param in dm_other .scalars :
111
+ self .scalars . append (param )
107
112
108
113
if errors :
109
114
raise ConfigValidationError .from_collected (errors )
110
115
111
116
def merge_with_existing_parameters (
112
- self , existing_parameters : list [ ParameterConfig ]
113
- ) -> tuple [ list [ ParameterConfig ], GenKwConfig ] :
117
+ self , existing_scalars : ScalarParameters
118
+ ) -> ScalarParameters :
114
119
"""
115
120
This method merges the design matrix parameters with the existing parameters and
116
- returns the new list of existing parameters, wherein we drop GEN_KW group having a full overlap with the design matrix group.
117
- GEN_KW group that was dropped will acquire a new name from the design matrix group.
118
- Additionally, the ParameterConfig which is the design matrix group is returned separately.
119
-
121
+ returns the new list of existing parameters.
120
122
Args:
121
- existing_parameters (List[ParameterConfig] ): List of existing parameters
123
+ existing_scalars (ScalarParameters ): existing scalar parameters
122
124
123
- Raises:
124
- ConfigValidationError: If there is a partial overlap between the design matrix group and any existing GEN_KW group
125
125
126
126
Returns:
127
- tuple[List[ParameterConfig], ParameterConfig]: List of existing parameters and the dedicated design matrix group
127
+ ScalarParameters: new set of ScalarParameters
128
128
"""
129
129
130
- new_param_config : list [ParameterConfig ] = []
130
+ all_params : list [ScalarParameter ] = []
131
131
132
- design_parameter_group = self .parameter_configuration
133
- design_keys = [e .name for e in design_parameter_group .transform_functions ]
134
-
135
- design_group_added = False
136
- for parameter_group in existing_parameters :
137
- if not isinstance (parameter_group , GenKwConfig ):
138
- new_param_config += [parameter_group ]
132
+ overlap_set = set ()
133
+ for parameter_sampled in existing_scalars .scalars :
134
+ if parameter_sampled .input_source == DataSource .DESIGN_MATRIX :
139
135
continue
140
- existing_keys = [e .name for e in parameter_group .transform_functions ]
141
- if set (existing_keys ) == set (design_keys ):
142
- if design_group_added :
143
- raise ConfigValidationError (
144
- "Multiple overlapping groups with design matrix found in existing parameters!\n "
145
- f"{ design_parameter_group .name } and { parameter_group .name } "
146
- )
147
-
148
- design_parameter_group .name = parameter_group .name
149
- design_group_added = True
150
- elif set (design_keys ) & set (existing_keys ):
151
- raise ConfigValidationError (
152
- "Overlapping parameter names found in design matrix!\n "
153
- f"{ DESIGN_MATRIX_GROUP } :{ design_keys } \n { parameter_group .name } :{ existing_keys } "
154
- "\n They need to match exactly or not at all."
155
- )
156
- else :
157
- new_param_config += [parameter_group ]
158
- return new_param_config , design_parameter_group
136
+ overlap = False
137
+ for parameter_design in self .scalars :
138
+ if parameter_sampled .param_name == parameter_design .param_name :
139
+ parameter_design .group_name = parameter_sampled .group_name
140
+ all_params .append (parameter_design )
141
+ overlap = True
142
+ overlap_set .add (parameter_sampled .param_name )
143
+ break
144
+ if not overlap :
145
+ all_params .append (parameter_sampled )
146
+
147
+ for parameter_design in self .scalars :
148
+ if parameter_design .param_name not in overlap_set :
149
+ all_params .append (parameter_design )
150
+
151
+ return ScalarParameters (
152
+ name = SCALAR_PARAMETERS_NAME ,
153
+ forward_init = False ,
154
+ update = True ,
155
+ scalars = all_params ,
156
+ )
159
157
160
158
def read_design_matrix (
161
159
self ,
162
- ) -> tuple [list [bool ], pd .DataFrame , GenKwConfig ]:
160
+ ) -> tuple [list [bool ], pd .DataFrame , list [ ScalarParameter ] ]:
163
161
# Read the parameter names (first row) as strings to prevent pandas from modifying them.
164
162
# This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet.
165
163
# By doing this, we can properly validate variable names, including detecting duplicates or missing names.
@@ -205,29 +203,25 @@ def read_design_matrix(
205
203
206
204
design_matrix_df = pd .concat ([design_matrix_df , default_df ], axis = 1 )
207
205
208
- transform_function_definitions : list [TransformFunctionDefinition ] = []
206
+ scalars : list [ScalarParameter ] = []
209
207
for parameter in design_matrix_df .columns :
210
- transform_function_definitions .append (
211
- TransformFunctionDefinition (
212
- name = parameter ,
213
- param_name = "RAW" ,
214
- values = [],
208
+ scalars .append (
209
+ ScalarParameter (
210
+ param_name = parameter ,
211
+ group_name = DESIGN_MATRIX_GROUP ,
212
+ input_source = DataSource .DESIGN_MATRIX ,
213
+ distribution = TransRawSettings (),
214
+ template_file = None ,
215
+ output_file = None ,
216
+ update = False ,
215
217
)
216
218
)
217
- parameter_configuration = GenKwConfig (
218
- name = DESIGN_MATRIX_GROUP ,
219
- forward_init = False ,
220
- template_file = None ,
221
- output_file = None ,
222
- transform_function_definitions = transform_function_definitions ,
223
- update = False ,
224
- )
225
219
226
220
reals = design_matrix_df .index .tolist ()
227
221
return (
228
222
[x in reals for x in range (max (reals ) + 1 )],
229
223
design_matrix_df ,
230
- parameter_configuration ,
224
+ scalars ,
231
225
)
232
226
233
227
@staticmethod
0 commit comments