Skip to content

Commit

Permalink
Export re-scaled controls
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Jan 31, 2025
1 parent 6e0574b commit ad8013b
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 54 deletions.
31 changes: 27 additions & 4 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,12 @@ def _init_batch_data(
control_values: NDArray[np.float64],
evaluator_context: EvaluatorContext,
cached_results: dict[int, Any],
prefix: str = "",
) -> dict[int, dict[str, Any]]:
def _add_controls(
controls_config: list[ControlConfig], values: NDArray[np.float64]
controls_config: list[ControlConfig],
values: NDArray[np.float64],
prefix: str = "",
) -> dict[str, Any]:
batch_data_item: dict[str, Any] = {}
value_list = values.tolist()
Expand All @@ -344,13 +347,28 @@ def _add_controls(
else:
variable_value = value_list.pop(0)
control_dict[variable.name] = variable_value
batch_data_item[control.name] = control_dict
batch_data_item[prefix + control.name] = control_dict
return batch_data_item

def _add_controls_with_rescaling(
controls_config: list[ControlConfig], values: NDArray[np.float64]
) -> dict[str, Any]:
batch_data_item = _add_controls(controls_config, values)
if self._ropt_transforms.variables is not None:
rescaled_item = _add_controls(
controls_config,
self._ropt_transforms.variables.backward(values),
prefix="rescaled-",
)
batch_data_item.update(rescaled_item)
return batch_data_item

active = evaluator_context.active
realizations = evaluator_context.realizations
return {
idx: _add_controls(self._everest_config.controls, control_values[idx, :])
idx: _add_controls_with_rescaling(
self._everest_config.controls, control_values[idx, :]
)
for idx in range(control_values.shape[0])
if (
idx not in cached_results
Expand Down Expand Up @@ -392,7 +410,12 @@ def _check_suffix(
f"Key {key} has suffixes, a suffix must be specified"
)

if set(controls.keys()) != set(self._everest_config.control_names):
control_names = set(self._everest_config.control_names)
if self._ropt_transforms.variables is not None:
control_names |= {
"rescaled-" + name for name in self._everest_config.control_names
}
if set(controls.keys()) != control_names:
err_msg = "Mismatch between initialized and provided control names."
raise KeyError(err_msg)

Expand Down
6 changes: 0 additions & 6 deletions src/everest/config/control_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,6 @@ def ropt_perturbation_type(self) -> PerturbationType:
def ropt_control_type(self) -> VariableType:
return VariableType[self.control_type.upper()]

@property
def has_auto_scale(self) -> bool:
return self.auto_scale or any(
variable.auto_scale for variable in self.variables
)

@model_validator(mode="after")
def validate_variables(self) -> Self:
if self.variables is None:
Expand Down
20 changes: 1 addition & 19 deletions src/everest/config/control_variable_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ropt.enums import VariableType

from .sampler_config import SamplerConfig
from .validation_utils import no_dots_in_string, valid_range
from .validation_utils import no_dots_in_string


class _ControlVariable(BaseModel):
Expand All @@ -34,24 +34,6 @@ class _ControlVariable(BaseModel):
initial value.
""",
)
auto_scale: bool | None = Field(
default=None,
description="""
Can be set to true to re-scale variable from the range
defined by [min, max] to the range defined by scaled_range (default [0, 1])
""",
)
scaled_range: Annotated[tuple[float, float] | None, AfterValidator(valid_range)] = (
Field(
default=None,
description="""
Can be used to set the range of the variable values
after scaling (default = [0, 1]).
This option has no effect if auto_scale is not set.
""",
)
)
min: float | None = Field(
default=None,
description="""
Expand Down
2 changes: 0 additions & 2 deletions src/everest/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ def _add_variable(
for key in [
"control_type",
"enabled",
"auto_scale",
"scaled_range",
"min",
"max",
"perturbation_magnitude",
Expand Down
2 changes: 1 addition & 1 deletion src/everest/optimizer/everest2ropt.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def everest2ropt(

def get_ropt_transforms(ever_config: EverestConfig) -> Transforms:
controls = FlattenedControls(ever_config.controls)
if any(item is not None for item in controls.auto_scales):
if any(controls.auto_scales):
variable_scaler = ControlScaler(
controls.lower_bounds,
controls.upper_bounds,
Expand Down
6 changes: 6 additions & 0 deletions src/everest/simulator/everest_to_ert.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,5 +526,11 @@ def _get_variables(
input_keys=_get_variables(control.variables),
output_file=control.name + ".json",
)
if control.auto_scale:
ens_config.parameter_configs["rescaled-" + control.name] = ExtParamConfig(
name="rescaled-" + control.name,
input_keys=_get_variables(control.variables),
output_file="rescaled-" + control.name + ".json",
)

return ert_config
17 changes: 10 additions & 7 deletions test-data/everest/math_func/jobs/distance3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
import json
import sys
from pathlib import Path


def compute_distance_squared(p, q):
Expand All @@ -24,23 +25,25 @@ def main(argv):
arg_parser.add_argument("--target-file", type=str)
arg_parser.add_argument("--target", nargs=3, type=float)
arg_parser.add_argument("--out", type=str)
arg_parser.add_argument("--scaling", nargs=4, type=float)
arg_parser.add_argument("--realization", type=float)
options, _ = arg_parser.parse_known_args(args=argv)

point = options.point if options.point else read_point(options.point_file)
point = (
options.point
if options.point
else read_point(
"rescaled-" + options.point_file
if Path("rescaled-" + options.point_file).exists()
else options.point_file
)
)
if len(point) != 3:
raise RuntimeError("Failed parsing point")

target = options.target if options.target else read_point(options.target_file)
if len(target) != 3:
raise RuntimeError("Failed parsing target")

if options.scaling is not None:
min_range, max_range, target_min, target_max = options.scaling
point = [(p - target_min) / (target_max - target_min) for p in point]
point = [p * (max_range - min_range) + min_range for p in point]

value = compute_distance_squared(point, target)
# If any realizations with an index > 0 are passed we make those incorrect
# by taking the negative value. This used by test_cvar.py.
Expand Down
1 change: 0 additions & 1 deletion tests/everest/test_math_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def test_math_func_auto_scaled_controls(
{"weights": {"point.x": 1.0, "point.y": 1.0}, "upper_bound": 0.5}
],
}
config_dict["forward_model"][0] += " --scaling -1 1 0.3 0.7"
config = EverestConfig.model_validate(config_dict)

# Act
Expand Down
14 changes: 0 additions & 14 deletions tests/everest/test_ropt_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,6 @@ def test_everest2ropt_controls_auto_scale():
assert numpy.allclose(ropt_config.variables.upper_bounds, 0.7)


def test_everest2ropt_variables_auto_scale():
config = EverestConfig.load_file(os.path.join(_CONFIG_DIR, _CONFIG_FILE))
controls = config.controls
controls[0].variables[1].auto_scale = True
controls[0].variables[1].scaled_range = [0.3, 0.7]
ropt_config = everest2ropt(config, transforms=get_ropt_transforms(config))
assert ropt_config.variables.lower_bounds[0] == 0.0
assert ropt_config.variables.upper_bounds[0] == 0.1
assert ropt_config.variables.lower_bounds[1] == 0.3
assert ropt_config.variables.upper_bounds[1] == 0.7
assert numpy.allclose(ropt_config.variables.lower_bounds[2:], 0.0)
assert numpy.allclose(ropt_config.variables.upper_bounds[2:], 0.1)


def test_everest2ropt_controls_input_constraint():
config = EverestConfig.load_file(
os.path.join(_CONFIG_DIR, "config_input_constraints.yml")
Expand Down

0 comments on commit ad8013b

Please sign in to comment.