1
1
import math
2
+ import os
2
3
import warnings
3
4
from dataclasses import dataclass
5
+ from enum import StrEnum
4
6
from pathlib import Path
5
- from typing import Literal
7
+ from typing import Any , Literal , Self , overload
6
8
7
9
import numpy as np
8
10
from scipy .stats import norm
9
11
10
- from .parameter_config import ParameterConfig
11
-
12
- # parameter_configs = List[FieldParameter, SurfaceParameter, ScalarParameter]
12
+ from ._str_to_bool import str_to_bool
13
+ from . parameter_config import parse_config
14
+ from . parsing import ConfigValidationError , ConfigWarning
13
15
14
16
15
17
@dataclass
@@ -149,10 +151,67 @@ class PolarsData:
149
151
data_set_file : Path
150
152
151
153
154
+ @overload
155
+ def _get_abs_path (file : None ) -> None :
156
+ pass
157
+
158
+
159
+ @overload
160
+ def _get_abs_path (file : str ) -> str :
161
+ pass
162
+
163
+
164
+ def _get_abs_path (file : str | None ) -> str | None :
165
+ if file is not None :
166
+ file = os .path .realpath (file )
167
+ return file
168
+
169
+
170
+ def get_distribution (name : str , values : list [str ]) -> Any :
171
+ return {
172
+ "NORMAL" : TransNormalSettings (mean = float (values [0 ]), std = float (values [1 ])),
173
+ "UNIFORM" : TransUnifSettings (min = float (values [0 ]), max = float (values [1 ])),
174
+ "TRUNC_NORMAL" : TransTruncNormalSettings (
175
+ mean = float (values [0 ]),
176
+ std = float (values [1 ]),
177
+ min = float (values [2 ]),
178
+ max = float (values [3 ]),
179
+ ),
180
+ "RAW" : TransRawSettings (),
181
+ "CONST" : TransConstSettings (value = float (values [0 ])),
182
+ "DUNIF" : TransDUnifSettings (
183
+ steps = int (values [0 ]), min = float (values [1 ]), max = float (values [2 ])
184
+ ),
185
+ "TRIANGULAR" : TransTriangularSettings (
186
+ min = float (values [0 ]), mode = float (values [1 ]), max = float (values [2 ])
187
+ ),
188
+ "ERRF" : TransErrfSettings (
189
+ min = float (values [0 ]),
190
+ max = float (values [1 ]),
191
+ skew = float (values [2 ]),
192
+ width = float (values [3 ]),
193
+ ),
194
+ "DERRF" : TransDerrfSettings (
195
+ steps = int (values [0 ]),
196
+ min = float (values [1 ]),
197
+ max = float (values [2 ]),
198
+ skew = float (values [3 ]),
199
+ width = float (values [4 ]),
200
+ ),
201
+ }[name ]
202
+
203
+
204
+ class DataSource (StrEnum ):
205
+ DESIGN_MATRIX = "design_matrix"
206
+ SAMPLED = "sampled"
207
+
208
+
152
209
@dataclass
153
- class ScalarParameter (ParameterConfig ):
154
- # name: str
155
- group : str
210
+ class ScalarParameter :
211
+ template_file : str | None
212
+ output_file : str | None
213
+ param_name : str
214
+ group_name : str
156
215
distribution : (
157
216
TransUnifSettings
158
217
| TransDUnifSettings
@@ -164,6 +223,129 @@ class ScalarParameter(ParameterConfig):
164
223
| TransDerrfSettings
165
224
| TransTriangularSettings
166
225
)
167
- active : bool
168
- input_source : Literal ["design_matrix" , "sampled" ]
169
- dataset_file : PolarsData
226
+ # active: bool
227
+ input_source : DataSource
228
+ # dataset_file: PolarsData | None
229
+
230
+ @classmethod
231
+ def from_config_list (cls , gen_kw : list [str ]) -> list [Self ]:
232
+ gen_kw_key = gen_kw [0 ]
233
+
234
+ positional_args , options = parse_config (gen_kw , 4 )
235
+ forward_init = str_to_bool (options .get ("FORWARD_INIT" , "FALSE" ))
236
+ init_file = _get_abs_path (options .get ("INIT_FILES" ))
237
+ update_parameter = str_to_bool (options .get ("UPDATE" , "TRUE" ))
238
+ errors = []
239
+
240
+ if len (positional_args ) == 2 :
241
+ parameter_file = _get_abs_path (positional_args [1 ])
242
+ parameter_file_context = positional_args [1 ]
243
+ template_file = None
244
+ output_file = None
245
+ elif len (positional_args ) == 4 :
246
+ output_file = positional_args [2 ]
247
+ parameter_file = _get_abs_path (positional_args [3 ])
248
+ parameter_file_context = positional_args [3 ]
249
+ template_file = _get_abs_path (positional_args [1 ])
250
+ if not os .path .isfile (template_file ):
251
+ errors .append (
252
+ ConfigValidationError .with_context (
253
+ f"No such template file: { template_file } " , positional_args [1 ]
254
+ )
255
+ )
256
+ elif Path (template_file ).stat ().st_size == 0 :
257
+ token = (
258
+ parameter_file_context .token
259
+ if hasattr (parameter_file_context , "token" )
260
+ else parameter_file_context
261
+ )
262
+ ConfigWarning .deprecation_warn (
263
+ f"The template file for GEN_KW ({ gen_kw_key } ) is empty. If templating is not needed, you "
264
+ f"can use GEN_KW with just the distribution file instead: GEN_KW { gen_kw_key } { token } " ,
265
+ positional_args [1 ],
266
+ )
267
+
268
+ else :
269
+ raise ConfigValidationError (
270
+ f"Unexpected positional arguments: { positional_args } "
271
+ )
272
+ if not os .path .isfile (parameter_file ):
273
+ errors .append (
274
+ ConfigValidationError .with_context (
275
+ f"No such parameter file: { parameter_file } " , parameter_file_context
276
+ )
277
+ )
278
+ elif Path (parameter_file ).stat ().st_size == 0 :
279
+ errors .append (
280
+ ConfigValidationError .with_context (
281
+ f"No parameters specified in { parameter_file } " ,
282
+ parameter_file_context ,
283
+ )
284
+ )
285
+
286
+ if forward_init :
287
+ errors .append (
288
+ ConfigValidationError .with_context (
289
+ "Loading GEN_KW from files created by the forward "
290
+ "model is not supported." ,
291
+ gen_kw ,
292
+ )
293
+ )
294
+
295
+ if init_file :
296
+ errors .append (
297
+ ConfigValidationError .with_context (
298
+ "Loading GEN_KW from init_files is not longer supported!" ,
299
+ gen_kw ,
300
+ )
301
+ )
302
+
303
+ if errors :
304
+ raise ConfigValidationError .from_collected (errors )
305
+
306
+ parameter_configuration : list [Self ] = []
307
+ with open (parameter_file , encoding = "utf-8" ) as file :
308
+ for line_number , item in enumerate (file ):
309
+ item = item .split ("--" )[0 ] # remove comments
310
+ if item .strip (): # only lines with content
311
+ items = item .split ()
312
+ if len (items ) < 2 :
313
+ errors .append (
314
+ ConfigValidationError .with_context (
315
+ f"Too few values on line { line_number } in parameter file { parameter_file } " ,
316
+ gen_kw ,
317
+ )
318
+ )
319
+ else :
320
+ parameter_configuration .append (
321
+ cls (
322
+ param_name = items [1 ],
323
+ input_source = DataSource .SAMPLED ,
324
+ group_name = gen_kw_key ,
325
+ distribution = get_distribution (items [0 ], items [2 :]),
326
+ template_file = template_file ,
327
+ output_file = output_file ,
328
+ )
329
+ )
330
+
331
+ if errors :
332
+ raise ConfigValidationError .from_collected (errors )
333
+
334
+ if gen_kw_key == "PRED" and update_parameter :
335
+ ConfigWarning .warn (
336
+ "GEN_KW PRED used to hold a special meaning and be "
337
+ "excluded from being updated.\n If the intention was "
338
+ "to exclude this from updates, set UPDATE:FALSE.\n " ,
339
+ gen_kw [0 ],
340
+ )
341
+ return parameter_configuration
342
+
343
+ # return cls(
344
+ # name=gen_kw_key,
345
+ # forward_init=forward_init,
346
+ # template_file=template_file,
347
+ # output_file=output_file,
348
+ # forward_init_file=init_file,
349
+ # transform_function_definitions=transform_function_definitions,
350
+ # update=update_parameter,
351
+ # )
0 commit comments