3
3
import pytest
4
4
5
5
from ert .config .design_matrix import DESIGN_MATRIX_GROUP , DesignMatrix
6
+ from ert .config .gen_kw_config import GenKwConfig , TransformFunctionDefinition
7
+
8
+
9
+ @pytest .mark .parametrize (
10
+ "parameters, error_msg" ,
11
+ [
12
+ pytest .param (
13
+ {"COEFFS" : ["a" , "b" ]},
14
+ "" ,
15
+ id = "genkw_replaced" ,
16
+ ),
17
+ pytest .param (
18
+ {"COEFFS" : ["a" ]},
19
+ "Overlapping parameter names found in design matrix!" ,
20
+ id = "ValidationErrorOverlapping" ,
21
+ ),
22
+ pytest .param (
23
+ {"COEFFS" : ["aa" , "bb" ], "COEFFS2" : ["cc" , "dd" ]},
24
+ "" ,
25
+ id = "DESIGN_MATRIX_GROUP" ,
26
+ ),
27
+ pytest .param (
28
+ {"COEFFS" : ["a" , "b" ], "COEFFS2" : ["a" , "b" ]},
29
+ "Multiple overlapping groups with design matrix found in existing parameters!" ,
30
+ id = "ValidationErrorMultipleGroups" ,
31
+ ),
32
+ ],
33
+ )
34
+ def test_read_and_merge_with_existing_parameters (tmp_path , parameters , error_msg ):
35
+ extra_genkw_config = []
36
+ if parameters :
37
+ for group_name in parameters :
38
+ extra_genkw_config .append (
39
+ GenKwConfig (
40
+ name = group_name ,
41
+ forward_init = False ,
42
+ template_file = "" ,
43
+ transform_function_definitions = [
44
+ TransformFunctionDefinition (param , "UNIFORM" , [0 , 1 ])
45
+ for param in parameters [group_name ]
46
+ ],
47
+ output_file = "kw.txt" ,
48
+ update = True ,
49
+ )
50
+ )
51
+
52
+ realizations = [0 , 1 , 2 ]
53
+ design_path = tmp_path / "design_matrix.xlsx"
54
+ design_matrix_df = pd .DataFrame (
55
+ {
56
+ "REAL" : realizations ,
57
+ "a" : [1 , 2 , 3 ],
58
+ "b" : [0 , 2 , 0 ],
59
+ }
60
+ )
61
+ default_sheet_df = pd .DataFrame ([["a" , 1 ], ["b" , 4 ]])
62
+ with pd .ExcelWriter (design_path ) as xl_write :
63
+ design_matrix_df .to_excel (xl_write , index = False , sheet_name = "DesignSheet01" )
64
+ default_sheet_df .to_excel (
65
+ xl_write , index = False , sheet_name = "DefaultValues" , header = False
66
+ )
67
+ design_matrix = DesignMatrix (design_path , "DesignSheet01" , "DefaultValues" )
68
+ if error_msg :
69
+ with pytest .raises (ValueError , match = error_msg ):
70
+ design_matrix .merge_with_existing_parameters (extra_genkw_config )
71
+ elif len (parameters ) == 1 :
72
+ new_config_parameters , design_group = (
73
+ design_matrix .merge_with_existing_parameters (extra_genkw_config )
74
+ )
75
+ assert len (new_config_parameters ) == 0
76
+ assert design_group .name == "COEFFS"
77
+ elif len (parameters ) == 2 :
78
+ new_config_parameters , design_group = (
79
+ design_matrix .merge_with_existing_parameters (extra_genkw_config )
80
+ )
81
+ assert len (new_config_parameters ) == 2
82
+ assert design_group .name == DESIGN_MATRIX_GROUP
6
83
7
84
8
85
def test_reading_design_matrix (tmp_path ):
@@ -23,10 +100,8 @@ def test_reading_design_matrix(tmp_path):
23
100
xl_write , index = False , sheet_name = "DefaultValues" , header = False
24
101
)
25
102
design_matrix = DesignMatrix (design_path , "DesignSheet01" , "DefaultValues" )
26
- design_matrix .read_design_matrix ()
27
103
design_params = design_matrix .parameter_configuration .get (DESIGN_MATRIX_GROUP , [])
28
104
assert all (param in design_params for param in ("a" , "b" , "c" , "one" , "d" ))
29
- assert design_matrix .num_realizations == 3
30
105
assert design_matrix .active_realizations == [True , True , False , False , True ]
31
106
32
107
@@ -62,9 +137,9 @@ def test_reading_design_matrix_validate_reals(tmp_path, real_column, error_msg):
62
137
default_sheet_df .to_excel (
63
138
xl_write , index = False , sheet_name = "DefaultValues" , header = False
64
139
)
65
- design_matrix = DesignMatrix ( design_path , "DesignSheet01" , "DefaultValues" )
140
+
66
141
with pytest .raises (ValueError , match = error_msg ):
67
- design_matrix . read_design_matrix ( )
142
+ DesignMatrix ( design_path , "DesignSheet01" , "DefaultValues" )
68
143
69
144
70
145
@pytest .mark .parametrize (
@@ -98,9 +173,9 @@ def test_reading_design_matrix_validate_headers(tmp_path, column_names, error_ms
98
173
default_sheet_df .to_excel (
99
174
xl_write , index = False , sheet_name = "DefaultValues" , header = False
100
175
)
101
- design_matrix = DesignMatrix ( design_path , "DesignSheet01" , "DefaultValues" )
176
+
102
177
with pytest .raises (ValueError , match = error_msg ):
103
- design_matrix . read_design_matrix ( )
178
+ DesignMatrix ( design_path , "DesignSheet01" , "DefaultValues" )
104
179
105
180
106
181
@pytest .mark .parametrize (
@@ -134,9 +209,9 @@ def test_reading_design_matrix_validate_cells(tmp_path, values, error_msg):
134
209
default_sheet_df .to_excel (
135
210
xl_write , index = False , sheet_name = "DefaultValues" , header = False
136
211
)
137
- design_matrix = DesignMatrix ( design_path , "DesignSheet01" , "DefaultValues" )
212
+
138
213
with pytest .raises (ValueError , match = error_msg ):
139
- design_matrix . read_design_matrix ( )
214
+ DesignMatrix ( design_path , "DesignSheet01" , "DefaultValues" )
140
215
141
216
142
217
@pytest .mark .parametrize (
@@ -180,9 +255,9 @@ def test_reading_default_sheet_validation(tmp_path, data, error_msg):
180
255
default_sheet_df .to_excel (
181
256
xl_write , index = False , sheet_name = "DefaultValues" , header = False
182
257
)
183
- design_matrix = DesignMatrix ( design_path , "DesignSheet01" , "DefaultValues" )
258
+
184
259
with pytest .raises (ValueError , match = error_msg ):
185
- design_matrix . read_design_matrix ( )
260
+ DesignMatrix ( design_path , "DesignSheet01" , "DefaultValues" )
186
261
187
262
188
263
def test_default_values_used (tmp_path ):
@@ -202,7 +277,6 @@ def test_default_values_used(tmp_path):
202
277
xl_write , index = False , sheet_name = "DefaultValues" , header = False
203
278
)
204
279
design_matrix = DesignMatrix (design_path , "DesignSheet01" , "DefaultValues" )
205
- design_matrix .read_design_matrix ()
206
280
df = design_matrix .design_matrix_df
207
281
np .testing .assert_equal (df [DESIGN_MATRIX_GROUP , "one" ], np .array ([1 , 1 , 1 , 1 ]))
208
282
np .testing .assert_equal (df [DESIGN_MATRIX_GROUP , "b" ], np .array ([0 , 2 , 0 , 1 ]))
0 commit comments