-
Notifications
You must be signed in to change notification settings - Fork 110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ScalarParameters replacing GenKW in parameter config #10095
Draft
xjules
wants to merge
6
commits into
equinor:main
Choose a base branch
from
xjules:draft_paramcfg
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
4a15c2a
Add ScalarParameters config
xjules 44de330
Fix empty scalars
xjules 13e5a63
Update export to xr.Dataset
xjules d371a4d
Update Scalars for new defaults
xjules e62c5a4
ConfigValidationError on init_files for GEN_KW
xjules 01bfca2
Comment out tests for gen_kw forward_init
xjules File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,13 +8,17 @@ | |
import pandas as pd | ||
from pandas.api.types import is_integer_dtype | ||
|
||
from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition | ||
|
||
from ._option_dict import option_dict | ||
from .parsing import ConfigValidationError, ErrorInfo | ||
from .scalar_parameter import ( | ||
DataSource, | ||
ScalarParameter, | ||
ScalarParameters, | ||
TransRawSettings, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from ert.config import ParameterConfig | ||
pass | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can remove this TYPE_CHECKING |
||
|
||
DESIGN_MATRIX_GROUP = "DESIGN_MATRIX" | ||
|
||
|
@@ -32,7 +36,7 @@ def __post_init__(self) -> None: | |
( | ||
self.active_realizations, | ||
self.design_matrix_df, | ||
self.parameter_configuration, | ||
self.scalars, | ||
) = self.read_design_matrix() | ||
except (ValueError, AttributeError) as exc: | ||
raise ConfigValidationError.with_context( | ||
|
@@ -102,66 +106,54 @@ def merge_with_other(self, dm_other: DesignMatrix) -> None: | |
except ValueError as exc: | ||
errors.append(ErrorInfo(f"Error when merging design matrices {exc}!")) | ||
|
||
for tfd in dm_other.parameter_configuration.transform_function_definitions: | ||
self.parameter_configuration.transform_function_definitions.append(tfd) | ||
for param in dm_other.scalars: | ||
self.scalars.append(param) | ||
|
||
if errors: | ||
raise ConfigValidationError.from_collected(errors) | ||
|
||
def merge_with_existing_parameters( | ||
self, existing_parameters: list[ParameterConfig] | ||
) -> tuple[list[ParameterConfig], GenKwConfig]: | ||
self, existing_scalars: ScalarParameters | ||
) -> ScalarParameters: | ||
""" | ||
This method merges the design matrix parameters with the existing parameters and | ||
returns the new list of existing parameters, wherein we drop GEN_KW group having a full overlap with the design matrix group. | ||
GEN_KW group that was dropped will acquire a new name from the design matrix group. | ||
Additionally, the ParameterConfig which is the design matrix group is returned separately. | ||
|
||
returns the new list of existing parameters. | ||
Args: | ||
existing_parameters (List[ParameterConfig]): List of existing parameters | ||
existing_scalars (ScalarParameters): existing scalar parameters | ||
|
||
Raises: | ||
ConfigValidationError: If there is a partial overlap between the design matrix group and any existing GEN_KW group | ||
|
||
Returns: | ||
tuple[List[ParameterConfig], ParameterConfig]: List of existing parameters and the dedicated design matrix group | ||
ScalarParameters: new set of ScalarParameters | ||
""" | ||
|
||
new_param_config: list[ParameterConfig] = [] | ||
|
||
design_parameter_group = self.parameter_configuration | ||
design_keys = [e.name for e in design_parameter_group.transform_functions] | ||
all_params: list[ScalarParameter] = [] | ||
|
||
design_group_added = False | ||
for parameter_group in existing_parameters: | ||
if not isinstance(parameter_group, GenKwConfig): | ||
new_param_config += [parameter_group] | ||
overlap_set = set() | ||
for existing_parameter in existing_scalars.scalars: | ||
if existing_parameter.input_source == DataSource.DESIGN_MATRIX: | ||
continue | ||
existing_keys = [e.name for e in parameter_group.transform_functions] | ||
if set(existing_keys) == set(design_keys): | ||
if design_group_added: | ||
raise ConfigValidationError( | ||
"Multiple overlapping groups with design matrix found in existing parameters!\n" | ||
f"{design_parameter_group.name} and {parameter_group.name}" | ||
) | ||
|
||
design_parameter_group.name = parameter_group.name | ||
design_parameter_group.template_file = parameter_group.template_file | ||
design_parameter_group.output_file = parameter_group.output_file | ||
design_group_added = True | ||
elif set(design_keys) & set(existing_keys): | ||
raise ConfigValidationError( | ||
"Overlapping parameter names found in design matrix!\n" | ||
f"{DESIGN_MATRIX_GROUP}:{design_keys}\n{parameter_group.name}:{existing_keys}" | ||
"\nThey need to match exactly or not at all." | ||
) | ||
else: | ||
new_param_config += [parameter_group] | ||
return new_param_config, design_parameter_group | ||
overlap = False | ||
for parameter_design in self.scalars: | ||
if existing_parameter.param_name == parameter_design.param_name: | ||
parameter_design.group_name = existing_parameter.group_name | ||
parameter_design.template_file = existing_parameter.template_file | ||
parameter_design.output_file = existing_parameter.output_file | ||
all_params.append(parameter_design) | ||
overlap = True | ||
overlap_set.add(existing_parameter.param_name) | ||
break | ||
if not overlap: | ||
all_params.append(existing_parameter) | ||
|
||
for parameter_design in self.scalars: | ||
if parameter_design.param_name not in overlap_set: | ||
all_params.append(parameter_design) | ||
|
||
return ScalarParameters(scalars=all_params) | ||
|
||
def read_design_matrix( | ||
self, | ||
) -> tuple[list[bool], pd.DataFrame, GenKwConfig]: | ||
) -> tuple[list[bool], pd.DataFrame, list[ScalarParameter]]: | ||
# Read the parameter names (first row) as strings to prevent pandas from modifying them. | ||
# This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet. | ||
# By doing this, we can properly validate variable names, including detecting duplicates or missing names. | ||
|
@@ -207,29 +199,25 @@ def read_design_matrix( | |
|
||
design_matrix_df = pd.concat([design_matrix_df, default_df], axis=1) | ||
|
||
transform_function_definitions: list[TransformFunctionDefinition] = [] | ||
scalars: list[ScalarParameter] = [] | ||
for parameter in design_matrix_df.columns: | ||
transform_function_definitions.append( | ||
TransformFunctionDefinition( | ||
name=parameter, | ||
param_name="RAW", | ||
values=[], | ||
scalars.append( | ||
ScalarParameter( | ||
param_name=parameter, | ||
group_name=DESIGN_MATRIX_GROUP, | ||
input_source=DataSource.DESIGN_MATRIX, | ||
distribution=TransRawSettings(), | ||
template_file=None, | ||
output_file=None, | ||
update=False, | ||
) | ||
) | ||
parameter_configuration = GenKwConfig( | ||
name=DESIGN_MATRIX_GROUP, | ||
forward_init=False, | ||
template_file=None, | ||
output_file=None, | ||
transform_function_definitions=transform_function_definitions, | ||
update=False, | ||
) | ||
|
||
reals = design_matrix_df.index.tolist() | ||
return ( | ||
[x in reals for x in range(max(reals) + 1)], | ||
design_matrix_df, | ||
parameter_configuration, | ||
scalars, | ||
) | ||
|
||
@staticmethod | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we keeping this assertion?