Skip to content

Commit 80c5e43

Browse files
committed
Add write_to_runpath
1 parent 0893983 commit 80c5e43

File tree

2 files changed

+54
-12
lines changed

2 files changed

+54
-12
lines changed

src/ert/config/parameter_config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def read_from_runpath(
7171
run_path: Path,
7272
real_nr: int,
7373
iteration: int,
74-
) -> xr.Dataset:
74+
) -> xr.Dataset | None:
7575
"""
7676
This function is responsible for converting the parameter
7777
from the forward model to the internal ert format

src/ert/config/scalar_parameter.py

+53-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
import os
33
import warnings
4+
from collections import defaultdict
45
from dataclasses import dataclass
56
from enum import StrEnum
67
from hashlib import sha256
@@ -357,6 +358,40 @@ def save_parameters(
357358
def __len__(self) -> int:
358359
return sum(len(params) for params in self.scalar_params.values())
359360

361+
def read_from_runpath(
362+
self,
363+
run_path: Path,
364+
real_nr: int,
365+
iteration: int,
366+
) -> None:
367+
"""
368+
forward_init will not be supported, so None for the moment
369+
"""
370+
return None
371+
372+
def write_to_runpath(
373+
self, run_path: Path, real_nr: int, ensemble: Ensemble
374+
) -> dict[str, dict[str, float]] | None:
375+
"""
376+
This function is responsible for converting the parameter
377+
from the internal ert format to the format the forward model
378+
expects
379+
"""
380+
df = ensemble.load_parameters_scalar(real_nr)
381+
df_transformed = df.filter(pl.col("type") == "transformed").drop("type")
382+
transformed_dict = {}
383+
384+
for col in df_transformed.columns:
385+
group_name, param_name = col.split(":")
386+
transformed_value = df_transformed.select(pl.col(col)).to_series()[0]
387+
388+
# Build the nested dictionary
389+
if group_name not in transformed_dict:
390+
transformed_dict[group_name] = {}
391+
transformed_dict[group_name][param_name] = transformed_value
392+
# todo log handling when distribution requires it
393+
return transformed_dict
394+
360395
@classmethod
361396
def from_config_list(cls, gen_kw: list[str]) -> Self:
362397
gen_kw_key = gen_kw[0]
@@ -434,6 +469,9 @@ def from_config_list(cls, gen_kw: list[str]) -> Self:
434469
raise ConfigValidationError.from_collected(errors)
435470

436471
all_params: dict[DataSource, list[ScalarParameter]] = {DataSource.SAMPLED: []}
472+
group_params: defaultdict[str, list[ScalarParameter]] = defaultdict(
473+
list[ScalarParameter]
474+
)
437475
with open(parameter_file, encoding="utf-8") as file:
438476
for line_number, item in enumerate(file):
439477
item = item.split("--")[0] # remove comments
@@ -447,16 +485,16 @@ def from_config_list(cls, gen_kw: list[str]) -> Self:
447485
)
448486
)
449487
else:
450-
all_params[DataSource.SAMPLED].append(
451-
ScalarParameter(
452-
param_name=items[1],
453-
input_source=DataSource.SAMPLED,
454-
group_name=gen_kw_key,
455-
distribution=get_distribution(items[0], items[2:]),
456-
template_file=template_file,
457-
output_file=output_file,
458-
)
488+
param = ScalarParameter(
489+
param_name=items[1],
490+
input_source=DataSource.SAMPLED,
491+
group_name=gen_kw_key,
492+
distribution=get_distribution(items[0], items[2:]),
493+
template_file=template_file,
494+
output_file=output_file,
459495
)
496+
all_params[DataSource.SAMPLED].append(param)
497+
group_params[gen_kw_key].append(param)
460498

461499
if errors:
462500
raise ConfigValidationError.from_collected(errors)
@@ -468,6 +506,10 @@ def from_config_list(cls, gen_kw: list[str]) -> Self:
468506
"to exclude this from updates, set UPDATE:FALSE.\n",
469507
gen_kw[0],
470508
)
471-
return ScalarParameters(
472-
forward_init=False, update=True, scalar_params=all_params
509+
return cls(
510+
name="SCALAR_PARAMETERS",
511+
forward_init=False,
512+
update=True,
513+
scalar_params=all_params,
514+
groups=group_params,
473515
)

0 commit comments

Comments
 (0)