From 8c9b2d759b4c243a1e0290694d8f1d945d39f199 Mon Sep 17 00:00:00 2001 From: "Jonas G. Drange" Date: Mon, 2 Mar 2020 11:57:14 +0100 Subject: [PATCH] Remove inactive summary observations --- ert_data/loader.py | 41 ++++++++++++++++++++++++++------------- tests/data/conftest.py | 1 + tests/data/test_loader.py | 36 +++++++++++++++++++++++----------- 3 files changed, 53 insertions(+), 25 deletions(-) diff --git a/ert_data/loader.py b/ert_data/loader.py index 8bf1578d1a1..b49c9a01b95 100644 --- a/ert_data/loader.py +++ b/ert_data/loader.py @@ -82,16 +82,6 @@ def load_block_data(facade, observation_key, case_name): return data -def load_summary_data(facade, observation_key, case_name): - data_key = facade.get_data_key_for_obs_key(observation_key) - return pd.concat( - [ - _add_summary_observations(facade, data_key, case_name), - _add_summary_data(facade, data_key, case_name), - ] - ) - - def _get_block_measured(ensamble_size, block_data): data = pd.DataFrame() for ensamble_nr in range(ensamble_size): @@ -99,16 +89,39 @@ def _get_block_measured(ensamble_size, block_data): return data -def _add_summary_data(facade, data_key, case_name): +def load_summary_data(facade, observation_key, case_name): + data_key = facade.get_data_key_for_obs_key(observation_key) + args = (facade, observation_key, data_key, case_name) + return pd.concat([ + _get_summary_data(*args), + _get_summary_observations(*args).pipe(_remove_inactive_report_steps, *args) + ]) + + +def _get_summary_data(facade, _, data_key, case_name): data = facade.load_all_summary_data(case_name, [data_key]) data = data[data_key].unstack(level=-1) return data.set_index(data.index.values) -def _add_summary_observations(facade, data_key, case_name): +def _get_summary_observations(facade, _, data_key, case_name): data = facade.load_observation_data(case_name, [data_key]).transpose() # The index from SummaryObservationCollector is {data_key} and STD_{data_key}" # to match the other data types this needs to be changed to OBS and STD, hence # the regex. - data = data.set_index(data.index.str.replace(r"\b" + data_key, "OBS", regex=True)) - return data.set_index(data.index.str.replace("_" + data_key, "")) + data = data.set_index( + data.index.str.replace(r"\b" + data_key, "OBS", regex=True) + ) + data = data.set_index(data.index.str.replace("_" + data_key, "")) + return data + + +def _remove_inactive_report_steps(data, facade, observation_key, *args): + # XXX: the data returned from the SummaryObservationCollector is not + # specific to an observation_key, this means that the dataset contains all + # observations on the data_key. Here the extra data is removed. + obs_vector = facade.get_observations()[observation_key] + active_indices = [] + for step in obs_vector.getStepList(): + active_indices.append(step - 1) + return data.iloc[:, active_indices] diff --git a/tests/data/conftest.py b/tests/data/conftest.py index 6cfdfcb311d..f410f740fe3 100644 --- a/tests/data/conftest.py +++ b/tests/data/conftest.py @@ -17,6 +17,7 @@ def facade(): facade.get_impl.return_value = Mock() facade.get_ensemble_size.return_value = 3 facade.get_observations.return_value = {"some_key": obs_mock} + facade.get_data_key_for_obs_key.return_value = "some_key" facade.get_current_case_name.return_value = "test_case" diff --git a/tests/data/test_loader.py b/tests/data/test_loader.py index c8e0cf6b5c4..2908222c15b 100644 --- a/tests/data/test_loader.py +++ b/tests/data/test_loader.py @@ -5,9 +5,9 @@ import pytest if sys.version_info >= (3, 3): - from unittest.mock import Mock, MagicMock + from unittest.mock import Mock, MagicMock, ANY else: - from mock import Mock, MagicMock + from mock import Mock, MagicMock, ANY def create_expected_data(): @@ -90,19 +90,33 @@ def test_load_block_data(facade, monkeypatch): @pytest.mark.usefixtures("facade") def test_load_summary_data(facade, monkeypatch): - mocked_get_summary_obs = Mock(return_value=create_summary_get_observations()) - mocked_get_summary_measured = Mock( + mocked_get_summary_observations = Mock(return_value=pd.DataFrame()) + mocked_get_summary_data = Mock(return_value=create_summary_get_observations()) + mocked_remove_inactive_report_steps = Mock( return_value=pd.DataFrame(data=[[10.0, 10.0, 10.0, 10.0]]) ) - monkeypatch.setattr(loader, "_add_summary_observations", mocked_get_summary_obs) - monkeypatch.setattr(loader, "_add_summary_data", mocked_get_summary_measured) + monkeypatch.setattr( + loader, "_get_summary_observations", mocked_get_summary_observations + ) + monkeypatch.setattr(loader, "_get_summary_data", mocked_get_summary_data) + monkeypatch.setattr( + loader, "_remove_inactive_report_steps", mocked_remove_inactive_report_steps + ) + + data_key = "some_key" + observation_key = facade.get_observations()[data_key].getDataKey() + case_name = "a_random_name" - result = loader.load_summary_data(facade, "some_key", "a_random_name") - mocked_get_summary_obs.assert_called_once_with( - facade, facade.get_data_key_for_obs_key("some_key"), "a_random_name" + result = loader.load_summary_data(facade, observation_key, case_name) + + mocked_get_summary_observations.assert_called_once_with( + facade, observation_key, data_key, case_name + ) + mocked_get_summary_data.assert_called_once_with( + facade, observation_key, data_key, case_name ) - mocked_get_summary_measured.assert_called_once_with( - facade, facade.get_data_key_for_obs_key("some_key"), "a_random_name" + mocked_remove_inactive_report_steps.assert_called_once_with( + ANY, facade, observation_key, data_key, case_name ) assert result.equals(create_expected_data())