1
1
import math
2
2
import os
3
3
import warnings
4
+ from collections import defaultdict
4
5
from dataclasses import dataclass
5
6
from enum import StrEnum
6
7
from hashlib import sha256
@@ -357,6 +358,40 @@ def save_parameters(
357
358
def __len__ (self ) -> int :
358
359
return sum (len (params ) for params in self .scalar_params .values ())
359
360
361
+ def read_from_runpath (
362
+ self ,
363
+ run_path : Path ,
364
+ real_nr : int ,
365
+ iteration : int ,
366
+ ) -> None :
367
+ """
368
+ forward_init will not be supported, so None for the moment
369
+ """
370
+ return None
371
+
372
+ def write_to_runpath (
373
+ self , run_path : Path , real_nr : int , ensemble : Ensemble
374
+ ) -> dict [str , dict [str , float ]] | None :
375
+ """
376
+ This function is responsible for converting the parameter
377
+ from the internal ert format to the format the forward model
378
+ expects
379
+ """
380
+ df = ensemble .load_parameters_scalar (real_nr )
381
+ df_transformed = df .filter (pl .col ("type" ) == "transformed" ).drop ("type" )
382
+ transformed_dict = {}
383
+
384
+ for col in df_transformed .columns :
385
+ group_name , param_name = col .split (":" )
386
+ transformed_value = df_transformed .select (pl .col (col )).to_series ()[0 ]
387
+
388
+ # Build the nested dictionary
389
+ if group_name not in transformed_dict :
390
+ transformed_dict [group_name ] = {}
391
+ transformed_dict [group_name ][param_name ] = transformed_value
392
+ # todo log handling when distribution requires it
393
+ return transformed_dict
394
+
360
395
@classmethod
361
396
def from_config_list (cls , gen_kw : list [str ]) -> Self :
362
397
gen_kw_key = gen_kw [0 ]
@@ -434,6 +469,9 @@ def from_config_list(cls, gen_kw: list[str]) -> Self:
434
469
raise ConfigValidationError .from_collected (errors )
435
470
436
471
all_params : dict [DataSource , list [ScalarParameter ]] = {DataSource .SAMPLED : []}
472
+ group_params : defaultdict [str , list [ScalarParameter ]] = defaultdict (
473
+ list [ScalarParameter ]
474
+ )
437
475
with open (parameter_file , encoding = "utf-8" ) as file :
438
476
for line_number , item in enumerate (file ):
439
477
item = item .split ("--" )[0 ] # remove comments
@@ -447,16 +485,16 @@ def from_config_list(cls, gen_kw: list[str]) -> Self:
447
485
)
448
486
)
449
487
else :
450
- all_params [DataSource .SAMPLED ].append (
451
- ScalarParameter (
452
- param_name = items [1 ],
453
- input_source = DataSource .SAMPLED ,
454
- group_name = gen_kw_key ,
455
- distribution = get_distribution (items [0 ], items [2 :]),
456
- template_file = template_file ,
457
- output_file = output_file ,
458
- )
488
+ param = ScalarParameter (
489
+ param_name = items [1 ],
490
+ input_source = DataSource .SAMPLED ,
491
+ group_name = gen_kw_key ,
492
+ distribution = get_distribution (items [0 ], items [2 :]),
493
+ template_file = template_file ,
494
+ output_file = output_file ,
459
495
)
496
+ all_params [DataSource .SAMPLED ].append (param )
497
+ group_params [gen_kw_key ].append (param )
460
498
461
499
if errors :
462
500
raise ConfigValidationError .from_collected (errors )
@@ -468,6 +506,10 @@ def from_config_list(cls, gen_kw: list[str]) -> Self:
468
506
"to exclude this from updates, set UPDATE:FALSE.\n " ,
469
507
gen_kw [0 ],
470
508
)
471
- return ScalarParameters (
472
- forward_init = False , update = True , scalar_params = all_params
509
+ return cls (
510
+ name = "SCALAR_PARAMETERS" ,
511
+ forward_init = False ,
512
+ update = True ,
513
+ scalar_params = all_params ,
514
+ groups = group_params ,
473
515
)
0 commit comments