Add ScalarParameters replacing GenKW in parameter config #10095

Draft · wants to merge 6 commits into base: main
53 changes: 34 additions & 19 deletions src/ert/analysis/_es_update.py
@@ -5,12 +5,7 @@
import time
from collections.abc import Callable, Iterable, Sequence
from fnmatch import fnmatch
from typing import (
TYPE_CHECKING,
Generic,
Self,
TypeVar,
)
from typing import TYPE_CHECKING, Generic, Self, TypeVar

import iterative_ensemble_smoother as ies
import numpy as np
@@ -19,7 +14,13 @@
import scipy
from iterative_ensemble_smoother.experimental import AdaptiveESMDA

from ert.config import ESSettings, GenKwConfig, ObservationGroups, UpdateSettings
from ert.config import (
ESSettings,
GenKwConfig,
ObservationGroups,
ScalarParameters,
UpdateSettings,
)

from . import misfit_preprocessor
from .event import (
@@ -31,10 +32,7 @@
AnalysisTimeEvent,
DataSection,
)
from .snapshots import (
ObservationAndResponseSnapshot,
SmootherSnapshot,
)
from .snapshots import ObservationAndResponseSnapshot, SmootherSnapshot

if TYPE_CHECKING:
import numpy.typing as npt
@@ -111,16 +109,26 @@ def _all_parameters(


def _save_param_ensemble_array_to_disk(
ensemble: Ensemble,
source_ensemble: Ensemble,
target_ensemble: Ensemble,
param_ensemble_array: npt.NDArray[np.float64],
param_group: str,
iens_active_index: npt.NDArray[np.int_],
) -> None:
config_node = ensemble.experiment.parameter_configuration[param_group]
for i, realization in enumerate(iens_active_index):
config_node.save_parameters(
ensemble, param_group, realization, param_ensemble_array[:, i]
config_node = target_ensemble.experiment.parameter_configuration[param_group]
if isinstance(config_node, ScalarParameters):
config_node.save_updated_parameters_and_copy_remaining(
source_ensemble,
target_ensemble,
param_group,
iens_active_index,
param_ensemble_array,
)
else:
for i, realization in enumerate(iens_active_index):
config_node.save_parameters(
target_ensemble, param_group, realization, param_ensemble_array[:, i]
)
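
For context, a minimal standalone sketch (plain numpy with made-up values, not the ert API) of the column layout both branches assume: one column per active realization, ordered as in iens_active_index.

import numpy as np

# Hypothetical values: two scalar parameters across three active realizations.
param_ensemble_array = np.array(
    [
        [0.10, 0.20, 0.30],  # first parameter across the active realizations
        [1.00, 2.00, 3.00],  # second parameter across the active realizations
    ]
)
iens_active_index = np.array([0, 2, 5])  # realizations 1, 3 and 4 are inactive

# The non-ScalarParameters branch pairs column i with realization iens_active_index[i]:
for i, realization in enumerate(iens_active_index):
    print(realization, param_ensemble_array[:, i])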


def _load_param_ensemble_array(
@@ -129,7 +137,11 @@ def _load_param_ensemble_array(
iens_active_index: npt.NDArray[np.int_],
) -> npt.NDArray[np.float64]:
config_node = ensemble.experiment.parameter_configuration[param_group]
return config_node.load_parameters(ensemble, param_group, iens_active_index)
if isinstance(config_node, ScalarParameters):
return config_node.load_parameters_to_update(ensemble, iens_active_index)
dataset = config_node.load_parameters(ensemble, param_group, iens_active_index)
    assert isinstance(dataset, np.ndarray), "dataset is not a numpy array"
Review comment (Contributor): Are we keeping this assertion?

return dataset
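
A self-contained mimic of this dispatch (hypothetical stand-in classes, not the real ert config types), illustrating that both branches are expected to hand the update step a plain 2-D numpy array.

import numpy as np


class ScalarParametersLike:
    """Stand-in for ScalarParameters: loads all updatable scalars in one call."""

    def load_parameters_to_update(self, ensemble, iens_active_index):
        return np.full((2, len(iens_active_index)), 0.5)


class OtherConfigLike:
    """Stand-in for any other parameter config: loads one group by name."""

    def load_parameters(self, ensemble, group, iens_active_index):
        return np.zeros((4, len(iens_active_index)))


def load(config_node, ensemble, group, iens_active_index):
    if isinstance(config_node, ScalarParametersLike):
        return config_node.load_parameters_to_update(ensemble, iens_active_index)
    dataset = config_node.load_parameters(ensemble, group, iens_active_index)
    assert isinstance(dataset, np.ndarray), "dataset is not a numpy array"
    return dataset


print(load(ScalarParametersLike(), None, "SCALAR_PARAMETERS", np.array([0, 2])).shape)  # (2, 2)
print(load(OtherConfigLike(), None, "SOME_FIELD", np.array([0, 2])).shape)  # (4, 2)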


def _expand_wildcards(
@@ -608,9 +620,12 @@ def correlation_callback(
logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))
start = time.time()

_save_param_ensemble_array_to_disk(
target_ensemble, param_ensemble_array, param_group, iens_active_index
source_ensemble,
target_ensemble,
param_ensemble_array,
param_group,
iens_active_index,
)
logger.info(
f"Storing data for {param_group} completed in {(time.time() - start) / 60} minutes"
12 changes: 12 additions & 0 deletions src/ert/config/__init__.py
@@ -31,6 +31,13 @@
from .parsing.observations_parser import ObservationType
from .queue_config import QueueConfig
from .response_config import InvalidResponseFile, ResponseConfig
from .scalar_parameter import (
SCALAR_PARAMETERS_NAME,
DataSource,
ScalarParameter,
ScalarParameters,
get_distribution,
)
from .summary_config import SummaryConfig
from .summary_observation import SummaryObservation
from .surface_config import SurfaceConfig
@@ -39,11 +46,13 @@

__all__ = [
"DESIGN_MATRIX_GROUP",
"SCALAR_PARAMETERS_NAME",
"AnalysisConfig",
"AnalysisModule",
"ConfigValidationError",
"ConfigValidationError",
"ConfigWarning",
"DataSource",
"DesignMatrix",
"ESSettings",
"EnkfObs",
@@ -70,6 +79,8 @@
"QueueConfig",
"QueueSystem",
"ResponseConfig",
"ScalarParameter",
"ScalarParameters",
"SummaryConfig",
"SummaryObservation",
"SurfaceConfig",
@@ -80,5 +91,6 @@
"WorkflowJob",
"capture_validation",
"field_transform",
"get_distribution",
"lint_file",
]
110 changes: 49 additions & 61 deletions src/ert/config/design_matrix.py
@@ -8,13 +8,17 @@
import pandas as pd
from pandas.api.types import is_integer_dtype

from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition

from ._option_dict import option_dict
from .parsing import ConfigValidationError, ErrorInfo
from .scalar_parameter import (
DataSource,
ScalarParameter,
ScalarParameters,
TransRawSettings,
)

if TYPE_CHECKING:
from ert.config import ParameterConfig
pass
Review comment (Contributor): Maybe we can remove this TYPE_CHECKING


DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"

@@ -32,7 +36,7 @@ def __post_init__(self) -> None:
(
self.active_realizations,
self.design_matrix_df,
self.parameter_configuration,
self.scalars,
) = self.read_design_matrix()
except (ValueError, AttributeError) as exc:
raise ConfigValidationError.with_context(
@@ -102,66 +106,54 @@ def merge_with_other(self, dm_other: DesignMatrix) -> None:
except ValueError as exc:
errors.append(ErrorInfo(f"Error when merging design matrices {exc}!"))

for tfd in dm_other.parameter_configuration.transform_function_definitions:
self.parameter_configuration.transform_function_definitions.append(tfd)
for param in dm_other.scalars:
self.scalars.append(param)

if errors:
raise ConfigValidationError.from_collected(errors)

def merge_with_existing_parameters(
self, existing_parameters: list[ParameterConfig]
) -> tuple[list[ParameterConfig], GenKwConfig]:
self, existing_scalars: ScalarParameters
) -> ScalarParameters:
"""
This method merges the design matrix parameters with the existing parameters and
returns the new list of existing parameters, wherein we drop GEN_KW group having a full overlap with the design matrix group.
GEN_KW group that was dropped will acquire a new name from the design matrix group.
Additionally, the ParameterConfig which is the design matrix group is returned separately.

returns the new list of existing parameters.
Args:
existing_parameters (List[ParameterConfig]): List of existing parameters
existing_scalars (ScalarParameters): existing scalar parameters

Raises:
ConfigValidationError: If there is a partial overlap between the design matrix group and any existing GEN_KW group

Returns:
tuple[List[ParameterConfig], ParameterConfig]: List of existing parameters and the dedicated design matrix group
ScalarParameters: new set of ScalarParameters
"""

new_param_config: list[ParameterConfig] = []

design_parameter_group = self.parameter_configuration
design_keys = [e.name for e in design_parameter_group.transform_functions]
all_params: list[ScalarParameter] = []

design_group_added = False
for parameter_group in existing_parameters:
if not isinstance(parameter_group, GenKwConfig):
new_param_config += [parameter_group]
overlap_set = set()
for existing_parameter in existing_scalars.scalars:
if existing_parameter.input_source == DataSource.DESIGN_MATRIX:
continue
existing_keys = [e.name for e in parameter_group.transform_functions]
if set(existing_keys) == set(design_keys):
if design_group_added:
raise ConfigValidationError(
"Multiple overlapping groups with design matrix found in existing parameters!\n"
f"{design_parameter_group.name} and {parameter_group.name}"
)

design_parameter_group.name = parameter_group.name
design_parameter_group.template_file = parameter_group.template_file
design_parameter_group.output_file = parameter_group.output_file
design_group_added = True
elif set(design_keys) & set(existing_keys):
raise ConfigValidationError(
"Overlapping parameter names found in design matrix!\n"
f"{DESIGN_MATRIX_GROUP}:{design_keys}\n{parameter_group.name}:{existing_keys}"
"\nThey need to match exactly or not at all."
)
else:
new_param_config += [parameter_group]
return new_param_config, design_parameter_group
overlap = False
for parameter_design in self.scalars:
if existing_parameter.param_name == parameter_design.param_name:
parameter_design.group_name = existing_parameter.group_name
parameter_design.template_file = existing_parameter.template_file
parameter_design.output_file = existing_parameter.output_file
all_params.append(parameter_design)
overlap = True
overlap_set.add(existing_parameter.param_name)
break
if not overlap:
all_params.append(existing_parameter)

for parameter_design in self.scalars:
if parameter_design.param_name not in overlap_set:
all_params.append(parameter_design)

return ScalarParameters(scalars=all_params)
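
A toy illustration of the merge semantics above, using plain dicts with made-up parameter names instead of ScalarParameter objects: a design-matrix column replaces an existing scalar with the same param_name (keeping the existing group name), scalars already sourced from the design matrix are skipped, and everything else is carried over unchanged.

existing = [
    {"param_name": "PORO", "group_name": "COARSE", "input_source": "sampled"},
    {"param_name": "PERM", "group_name": "COARSE", "input_source": "sampled"},
]
design = [
    {"param_name": "PORO", "group_name": "DESIGN_MATRIX", "input_source": "design_matrix"},
    {"param_name": "DEPTH", "group_name": "DESIGN_MATRIX", "input_source": "design_matrix"},
]

merged, overlap = [], set()
for old in existing:
    match = next((dm for dm in design if dm["param_name"] == old["param_name"]), None)
    if match is not None:
        match["group_name"] = old["group_name"]  # design value wins, existing group name is kept
        merged.append(match)
        overlap.add(old["param_name"])
    else:
        merged.append(old)
merged += [dm for dm in design if dm["param_name"] not in overlap]

print([(p["param_name"], p["group_name"]) for p in merged])
# [('PORO', 'COARSE'), ('PERM', 'COARSE'), ('DEPTH', 'DESIGN_MATRIX')]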

def read_design_matrix(
self,
) -> tuple[list[bool], pd.DataFrame, GenKwConfig]:
) -> tuple[list[bool], pd.DataFrame, list[ScalarParameter]]:
# Read the parameter names (first row) as strings to prevent pandas from modifying them.
# This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet.
# By doing this, we can properly validate variable names, including detecting duplicates or missing names.
@@ -207,29 +199,25 @@

design_matrix_df = pd.concat([design_matrix_df, default_df], axis=1)

transform_function_definitions: list[TransformFunctionDefinition] = []
scalars: list[ScalarParameter] = []
for parameter in design_matrix_df.columns:
transform_function_definitions.append(
TransformFunctionDefinition(
name=parameter,
param_name="RAW",
values=[],
scalars.append(
ScalarParameter(
param_name=parameter,
group_name=DESIGN_MATRIX_GROUP,
input_source=DataSource.DESIGN_MATRIX,
distribution=TransRawSettings(),
template_file=None,
output_file=None,
update=False,
)
)
parameter_configuration = GenKwConfig(
name=DESIGN_MATRIX_GROUP,
forward_init=False,
template_file=None,
output_file=None,
transform_function_definitions=transform_function_definitions,
update=False,
)

reals = design_matrix_df.index.tolist()
return (
[x in reals for x in range(max(reals) + 1)],
design_matrix_df,
parameter_configuration,
scalars,
)
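
Roughly, for a hypothetical two-column design sheet the tail of read_design_matrix yields one RAW-distributed, non-updatable scalar per column plus the active-realization mask (mimicked here with plain dicts instead of ScalarParameter/TransRawSettings).

import pandas as pd

design_matrix_df = pd.DataFrame(
    {"PORO": [0.10, 0.20], "PERM": [100.0, 200.0]}, index=[0, 1]
)

scalars = [
    {
        "param_name": column,
        "group_name": "DESIGN_MATRIX",
        "input_source": "design_matrix",
        "distribution": "raw",
        "update": False,
    }
    for column in design_matrix_df.columns
]

reals = design_matrix_df.index.tolist()
active = [x in reals for x in range(max(reals) + 1)]
print([s["param_name"] for s in scalars], active)  # ['PORO', 'PERM'] [True, True]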

@staticmethod