1
1
import math
2
+ import os
2
3
import warnings
3
4
from dataclasses import dataclass
5
+ from enum import StrEnum
4
6
from pathlib import Path
5
- from typing import Literal
7
+ from typing import Any , Literal , Self , overload
6
8
7
9
import numpy as np
8
10
from scipy .stats import norm
9
11
10
- from .parameter_config import ParameterConfig
11
-
12
- # parameter_configs = List[FieldParameter, SurfaceParameter, ScalarParameter]
12
+ from ._str_to_bool import str_to_bool
13
+ from . parameter_config import parse_config
14
+ from . parsing import ConfigValidationError , ConfigWarning
13
15
14
16
15
17
@dataclass
@@ -149,10 +151,56 @@ class PolarsData:
149
151
data_set_file : Path
150
152
151
153
154
+ @overload
155
+ def _get_abs_path (file : None ) -> None :
156
+ pass
157
+
158
+
159
+ def get_distribution (name : str , values : list [str ]) -> Any :
160
+ return {
161
+ "NORMAL" : TransNormalSettings (mean = float (values [0 ]), std = float (values [1 ])),
162
+ "UNIFORM" : TransUnifSettings (min = float (values [0 ]), max = float (values [1 ])),
163
+ "TRUNC_NORMAL" : TransTruncNormalSettings (
164
+ mean = float (values [0 ]),
165
+ std = float (values [1 ]),
166
+ min = float (values [2 ]),
167
+ max = float (values [3 ]),
168
+ ),
169
+ "RAW" : TransRawSettings (),
170
+ "CONST" : TransConstSettings (value = float (values [0 ])),
171
+ "DUNIF" : TransDUnifSettings (
172
+ steps = int (values [0 ]), min = float (values [1 ]), max = float (values [2 ])
173
+ ),
174
+ "TRIANGULAR" : TransTriangularSettings (
175
+ min = float (values [0 ]), mode = float (values [1 ]), max = float (values [2 ])
176
+ ),
177
+ "ERRF" : TransErrfSettings (
178
+ min = float (values [0 ]),
179
+ max = float (values [1 ]),
180
+ skew = float (values [2 ]),
181
+ width = float (values [3 ]),
182
+ ),
183
+ "DERRF" : TransDerrfSettings (
184
+ steps = int (values [0 ]),
185
+ min = float (values [1 ]),
186
+ max = float (values [2 ]),
187
+ skew = float (values [3 ]),
188
+ width = float (values [4 ]),
189
+ ),
190
+ }[name ]
191
+
192
+
193
+ class DataSource (StrEnum ):
194
+ DESIGN_MATRIX = "design_matrix"
195
+ SAMPLED = "sampled"
196
+
197
+
152
198
@dataclass
153
- class ScalarParameter (ParameterConfig ):
154
- # name: str
155
- group : str
199
+ class ScalarParameter :
200
+ template_file : str | None
201
+ output_file : str | None
202
+ param_name : str
203
+ group_name : str
156
204
distribution : (
157
205
TransUnifSettings
158
206
| TransDUnifSettings
@@ -164,6 +212,130 @@ class ScalarParameter(ParameterConfig):
164
212
| TransDerrfSettings
165
213
| TransTriangularSettings
166
214
)
167
- active : bool
168
- input_source : Literal ["design_matrix" , "sampled" ]
169
- dataset_file : PolarsData
215
+ # active: bool
216
+ input_source : DataSource
217
+ # dataset_file: PolarsData | None
218
+
219
+ @classmethod
220
+ def from_config_list (cls , gen_kw : list [str ]) -> list [Self ]:
221
+ gen_kw_key = gen_kw [0 ]
222
+
223
+ positional_args , options = parse_config (gen_kw , 4 )
224
+ forward_init = str_to_bool (options .get ("FORWARD_INIT" , "FALSE" ))
225
+ init_file = _get_abs_path (options .get ("INIT_FILES" ))
226
+ update_parameter = str_to_bool (options .get ("UPDATE" , "TRUE" ))
227
+ errors = []
228
+
229
+ if len (positional_args ) == 2 :
230
+ parameter_file = _get_abs_path (positional_args [1 ])
231
+ parameter_file_context = positional_args [1 ]
232
+ template_file = None
233
+ output_file = None
234
+ elif len (positional_args ) == 4 :
235
+ output_file = positional_args [2 ]
236
+ parameter_file = _get_abs_path (positional_args [3 ])
237
+ parameter_file_context = positional_args [3 ]
238
+ template_file = _get_abs_path (positional_args [1 ])
239
+ if not os .path .isfile (template_file ):
240
+ errors .append (
241
+ ConfigValidationError .with_context (
242
+ f"No such template file: { template_file } " , positional_args [1 ]
243
+ )
244
+ )
245
+ elif Path (template_file ).stat ().st_size == 0 :
246
+ token = (
247
+ parameter_file_context .token
248
+ if hasattr (parameter_file_context , "token" )
249
+ else parameter_file_context
250
+ )
251
+ ConfigWarning .deprecation_warn (
252
+ f"The template file for GEN_KW ({ gen_kw_key } ) is empty. If templating is not needed, you "
253
+ f"can use GEN_KW with just the distribution file instead: GEN_KW { gen_kw_key } { token } " ,
254
+ positional_args [1 ],
255
+ )
256
+
257
+ else :
258
+ raise ConfigValidationError (
259
+ f"Unexpected positional arguments: { positional_args } "
260
+ )
261
+ if not os .path .isfile (parameter_file ):
262
+ errors .append (
263
+ ConfigValidationError .with_context (
264
+ f"No such parameter file: { parameter_file } " , parameter_file_context
265
+ )
266
+ )
267
+ elif Path (parameter_file ).stat ().st_size == 0 :
268
+ errors .append (
269
+ ConfigValidationError .with_context (
270
+ f"No parameters specified in { parameter_file } " ,
271
+ parameter_file_context ,
272
+ )
273
+ )
274
+
275
+ if forward_init :
276
+ errors .append (
277
+ ConfigValidationError .with_context (
278
+ "Loading GEN_KW from files created by the forward "
279
+ "model is not supported." ,
280
+ gen_kw ,
281
+ )
282
+ )
283
+
284
+ if init_file :
285
+ errors .append (
286
+ ConfigValidationError .with_context (
287
+ "Loading GEN_KW from init_files is not longer supported!" ,
288
+ gen_kw ,
289
+ )
290
+ )
291
+
292
+ if errors :
293
+ raise ConfigValidationError .from_collected (errors )
294
+
295
+ parameter_configuration : list [Self ] = []
296
+ # transform_function_definitions: list[TransformFunctionDefinition] = []
297
+ with open (parameter_file , encoding = "utf-8" ) as file :
298
+ for line_number , item in enumerate (file ):
299
+ item = item .split ("--" )[0 ] # remove comments
300
+ if item .strip (): # only lines with content
301
+ items = item .split ()
302
+ if len (items ) < 2 :
303
+ errors .append (
304
+ ConfigValidationError .with_context (
305
+ f"Too few values on line { line_number } in parameter file { parameter_file } " ,
306
+ gen_kw ,
307
+ )
308
+ )
309
+ else :
310
+ parameter_configuration .append (
311
+ cls (
312
+ param_name = items [1 ],
313
+ input_source = DataSource .SAMPLED ,
314
+ group_name = gen_kw_key ,
315
+ distribution = get_distribution (items [0 ], items [2 :]),
316
+ template_file = template_file ,
317
+ output_file = output_file ,
318
+ )
319
+ )
320
+
321
+ if errors :
322
+ raise ConfigValidationError .from_collected (errors )
323
+
324
+ if gen_kw_key == "PRED" and update_parameter :
325
+ ConfigWarning .warn (
326
+ "GEN_KW PRED used to hold a special meaning and be "
327
+ "excluded from being updated.\n If the intention was "
328
+ "to exclude this from updates, set UPDATE:FALSE.\n " ,
329
+ gen_kw [0 ],
330
+ )
331
+ return parameter_configuration
332
+
333
+ # return cls(
334
+ # name=gen_kw_key,
335
+ # forward_init=forward_init,
336
+ # template_file=template_file,
337
+ # output_file=output_file,
338
+ # forward_init_file=init_file,
339
+ # transform_function_definitions=transform_function_definitions,
340
+ # update=update_parameter,
341
+ # )
0 commit comments