Commit

Fixup gen_data_config
Yngve S. Kristiansen committed Sep 13, 2024
1 parent 0c62f1d commit 26b980d
Showing 5 changed files with 84 additions and 52 deletions.
90 changes: 50 additions & 40 deletions src/ert/analysis/_es_update.py
@@ -19,6 +19,7 @@
import iterative_ensemble_smoother as ies
import numpy as np
import pandas as pd
import polars

Check failure on line 22 in src/ert/analysis/_es_update.py (GitHub Actions / type-checking (3.12)): Cannot find implementation or library stub for module named "polars"
import psutil
from iterative_ensemble_smoother.experimental import (
AdaptiveESMDA,
@@ -153,46 +154,55 @@ def _get_observations_and_responses(
observation_values = []
observation_errors = []
indexes = []
observations = ensemble.experiment.observations
for obs in selected_observations:
observation = observations[obs]
group = observation.attrs["response"]
all_responses = ensemble.load_responses(group, tuple(iens_active_index))
if "time" in observation.coords:
all_responses = all_responses.reindex(
time=observation.time,
method="nearest",
tolerance="1s",
)
try:
observations_and_responses = observation.merge(all_responses, join="left")
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched index for: "
f"Observation: {obs} attached to response: {group}"
) from e

observation_keys.append([obs] * observations_and_responses["observations"].size)

if group == "summary":
indexes.append(
[
np.datetime_as_string(e, unit="s")
for e in observations_and_responses["time"].data
]
)
else:
indexes.append(
[
f"{e[0]}, {e[1]}"
for e in zip(
list(observations_and_responses["report_step"].data)
* len(observations_and_responses["index"].data),
observations_and_responses["index"].data,
)
]
)

observations_by_type = ensemble.experiment.observations
for response_type in ensemble.experiment.response_info:
observations_for_type = observations_by_type[response_type].filter(
polars.col("observation_key").is_in(selected_observations)
)
responses_for_type = ensemble.load_responses(
response_type, realizations=tuple(iens_active_index)
)
joined = observations_for_type.join(responses_for_type, how="left")

Check failure on line 165 in src/ert/analysis/_es_update.py (GitHub Actions / check-style (3.12)): Local variable `joined` is assigned to but never used

#
# observation = None
# # group = observation.attrs["response"]
# all_responses = ensemble.load_responses(group, tuple(iens_active_index))
# if "time" in observation.coords:
# all_responses = all_responses.reindex(
# time=observation.time,
# method="nearest",
# tolerance="1s",
# )
# try:
# observations_and_responses = observation.merge(all_responses, join="left")
# except KeyError as e:
# raise ErtAnalysisError(
# f"Mismatched index for: "
# f"Observation: {obs} attached to response: {group}"
# ) from e
#
# observation_keys.append([obs] * observations_and_responses["observations"].size)
#
# if group == "summary":
# indexes.append(
# [
# np.datetime_as_string(e, unit="s")
# for e in observations_and_responses["time"].data
# ]
# )
# else:
# indexes.append(
# [
# f"{e[0]}, {e[1]}"
# for e in zip(
# list(observations_and_responses["report_step"].data)
# * len(observations_and_responses["index"].data),
# observations_and_responses["index"].data,
# )
# ]
# )
observations_and_responses = None
observation_values.append(
observations_and_responses["observations"].data.ravel()
)
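For orientation (not part of the commit): a minimal, self-contained sketch of the polars pattern the new _get_observations_and_responses code moves towards, i.e. filtering the per-type observation frame by observation key and left-joining it onto the response frame. All frame contents, key values, and the explicit `on` columns below are illustrative assumptions, not ert's actual data.

import polars

# Illustrative observation frame for one response type; column names follow
# the commit: response_key / observation_key / index / report_step / observations / std.
observations = polars.DataFrame(
    {
        "response_key": ["GEN_1", "GEN_1"],
        "observation_key": ["OBS_1", "OBS_1"],
        "index": [0, 1],
        "report_step": [0, 0],
        "observations": polars.Series([0.1, 0.2], dtype=polars.Float32),
        "std": polars.Series([0.05, 0.05], dtype=polars.Float32),
    }
)

# Illustrative response frame for one realization, with the same key columns plus "values".
responses = polars.DataFrame(
    {
        "response_key": ["GEN_1", "GEN_1"],
        "index": [0, 1],
        "report_step": [0, 0],
        "realization": [0, 0],
        "values": polars.Series([0.11, 0.19], dtype=polars.Float32),
    }
)

selected_observations = ["OBS_1"]
observations_for_type = observations.filter(
    polars.col("observation_key").is_in(selected_observations)
)
# Left join: keep every selected observation row and attach matching response values.
joined = observations_for_type.join(
    responses, on=["response_key", "index", "report_step"], how="left"
)
print(joined)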
19 changes: 11 additions & 8 deletions src/ert/config/gen_data_config.py
@@ -6,7 +6,6 @@

import numpy as np
import polars

Check failure on line 8 in src/ert/config/gen_data_config.py (GitHub Actions / type-checking (3.12)): Cannot find implementation or library stub for module named "polars"
import xarray as xr
from typing_extensions import Self

from ert.validation import rangestring_to_list
@@ -119,9 +118,9 @@ def _read_file(filename: Path, report_step: int) -> polars.DataFrame:
data[active_list == 0] = np.nan
return polars.DataFrame(
{
"values": polars.Series(data, dtype=polars.Float32),
"index": np.arange(len(data)),
"report_step": report_step,
"index": np.arange(len(data)),
"values": polars.Series(data, dtype=polars.Float32),
}
)

@@ -151,16 +150,16 @@ def _read_file(filename: Path, report_step: int) -> polars.DataFrame:
except ValueError as err:
errors.append(str(err))

ds_all_report_steps = xr.concat(
datasets_per_report_step, dim="report_step"
).expand_dims(name=[name])
ds_all_report_steps = polars.concat(datasets_per_report_step)
ds_all_report_steps.insert_column(
0, polars.Series("response_key", [name] * len(ds_all_report_steps))
)
datasets_per_name.append(ds_all_report_steps)

if errors:
raise ValueError(f"Error reading GEN_DATA: {self.name}, errors: {errors}")

combined = xr.concat(datasets_per_name, dim="name")
combined.attrs["response"] = "gen_data"
combined = polars.concat(datasets_per_name)
return combined

def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]]]:
@@ -174,5 +173,9 @@ def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]]
def response_type(self) -> str:
return "gen_data"

@property
def primary_key(self) -> List[str]:
return ["index", "report_step"]


responses_index.add_response_type(GenDataConfig)
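A rough sketch (assumed values, not the commit's code) of the polars.concat / insert_column pattern that replaces xr.concat(...).expand_dims(name=[name]) above: per-report-step frames are stacked vertically and a constant response_key column is prepended to the combined frame.

import numpy as np
import polars

# Two illustrative per-report-step frames, shaped like _read_file's output.
step_0 = polars.DataFrame(
    {
        "report_step": [0, 0, 0],
        "index": np.arange(3),
        "values": polars.Series([1.0, 2.0, 3.0], dtype=polars.Float32),
    }
)
step_1 = polars.DataFrame(
    {
        "report_step": [1, 1, 1],
        "index": np.arange(3),
        "values": polars.Series([4.0, 5.0, 6.0], dtype=polars.Float32),
    }
)

# Stack the report steps, then prepend a constant "response_key" column so the
# combined frame carries the name that expand_dims(name=[name]) used to add.
combined = polars.concat([step_0, step_1])
combined.insert_column(
    0, polars.Series("response_key", ["SOME_GEN_DATA_KEY"] * len(combined))
)
print(combined)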
11 changes: 9 additions & 2 deletions src/ert/config/observation_vector.py
@@ -29,6 +29,8 @@ def __len__(self) -> int:

def to_dataset(self, active_list: List[int]) -> polars.DataFrame:
if self.observation_type == EnkfObservationImplementationType.GEN_OBS:
actual_response_key = self.data_key
actual_observation_key = self.observation_key
dataframes = []
for time_step, node in self.observations.items():
if active_list and time_step not in active_list:
@@ -38,7 +40,8 @@ def to_dataset(self, active_list: List[int]) -> polars.DataFrame:
dataframes.append(
polars.DataFrame(
{
"name": self.data_key,
"response_key": actual_response_key,
"observation_key": actual_observation_key,
"index": node.indices,
"report_step": time_step,
"observations": polars.Series(
@@ -52,6 +55,8 @@ def to_dataset(self, active_list: List[int]) -> polars.DataFrame:
return combined # type: ignore

Check failure on line 55 in src/ert/config/observation_vector.py (GitHub Actions / type-checking (3.12)): Unused "type: ignore" comment
elif self.observation_type == EnkfObservationImplementationType.SUMMARY_OBS:
observations = []
actual_response_key = self.observation_key
actual_observation_keys = []
errors = []
dates = list(self.observations.keys())
if active_list:
@@ -60,12 +65,14 @@ def to_dataset(self, active_list: List[int]) -> polars.DataFrame:
for time_step in dates:
n = self.observations[time_step]
assert isinstance(n, SummaryObservation)
actual_observation_keys.append(n.observation_key)
observations.append(n.value)
errors.append(n.std)

return polars.DataFrame(
{
"name": self.observation_key,
"response_key": actual_response_key,
"observation_key": actual_observation_keys,
"time": dates,
"observations": polars.Series(observations, dtype=polars.Float32),
"std": polars.Series(errors, dtype=polars.Float32),
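For reference, a sketch with made-up values of the column layout that to_dataset now produces for the SUMMARY_OBS branch: response_key names the response (the summary keyword) and observation_key names each individual observation. The keyword and observation names below are illustrative assumptions.

from datetime import datetime

import polars

# Hypothetical summary observation frame: one row per observed time step.
summary_obs = polars.DataFrame(
    {
        "response_key": ["FOPR", "FOPR"],
        "observation_key": ["FOPR_OBS_1", "FOPR_OBS_2"],
        "time": [datetime(2020, 1, 1), datetime(2020, 2, 1)],
        "observations": polars.Series([0.9, 1.1], dtype=polars.Float32),
        "std": polars.Series([0.1, 0.1], dtype=polars.Float32),
    }
)
print(summary_obs)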
14 changes: 13 additions & 1 deletion src/ert/storage/local_experiment.py
@@ -304,13 +304,25 @@ def update_parameters(self) -> List[str]:
return [p.name for p in self.parameter_configuration.values() if p.update]

@cached_property
def observations(self) -> Dict[str, xr.Dataset]:
def observations(self) -> Dict[str, polars.DataFrame]:
observations = sorted(self.mount_point.glob("observations/*"))
return {
observation.name: polars.read_parquet(f"{observation}")
for observation in observations
}

@cached_property
def observation_keys(self) -> List[str]:
"""
Gets all \"name\" values for all observations. I.e.,
the summary keyword, the gen_data observation name etc.
"""
keys = []
for df in self.observations.values():
keys.extend(df["observation_key"].unique())

return sorted(keys)

@cached_property
def response_key_to_response_type(self) -> Dict[str, str]:
mapping = {}
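A standalone sketch (hypothetical helper functions, mirroring the two cached properties above) of loading the per-response-type parquet observation frames from storage and collecting the unique observation keys:

from pathlib import Path
from typing import Dict, List

import polars


def load_observations(mount_point: Path) -> Dict[str, polars.DataFrame]:
    # One parquet file per response type under observations/, keyed by file name.
    return {
        path.name: polars.read_parquet(path)
        for path in sorted(mount_point.glob("observations/*"))
    }


def collect_observation_keys(observations: Dict[str, polars.DataFrame]) -> List[str]:
    # Every unique "observation_key" across all response types, sorted.
    keys: List[str] = []
    for df in observations.values():
        keys.extend(df["observation_key"].unique())
    return sorted(keys)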
2 changes: 1 addition & 1 deletion tests/unit_tests/analysis/test_es_update.py
@@ -98,7 +98,7 @@ def test_update_report(
smoother_update(
prior_ens,
posterior_ens,
list(ert_config.observations.keys()),
list(experiment.observation_keys),
ert_config.ensemble_config.parameters,
UpdateSettings(auto_scale_observations=misfit_preprocess),
ESSettings(inversion="subspace"),

0 comments on commit 26b980d
