Skip to content

Commit

Permalink
Refactored PlotWindow code to access data through an api (PlotApi).
Browse files Browse the repository at this point in the history
  • Loading branch information
mortalisk committed Mar 10, 2020
1 parent 69662eb commit a035f98
Show file tree
Hide file tree
Showing 32 changed files with 911 additions and 764 deletions.
29 changes: 16 additions & 13 deletions ert_data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def data_loader_factory(observation_type):
raise TypeError("Unknown observation type: {}".format(observation_type))


def load_general_data(facade, observation_key, case_name):
def load_general_data(facade, observation_key, case_name, include_data=True):
obs_vector = facade.get_observations()[observation_key]
data_key = obs_vector.getDataKey()

Expand All @@ -26,7 +26,6 @@ def load_general_data(facade, observation_key, case_name):
for time_step in obs_vector.getStepList().asList():
# Fetch, then transpose the simulation data in order to make it
# conform with the GenObservation data structure.
gen_data = facade.load_gen_data(case_name, data_key, time_step).T

# Observations and its standard deviation are a subset of the simulation data.
# The index_list refers to indices in the simulation data. In order to
Expand All @@ -48,12 +47,14 @@ def load_general_data(facade, observation_key, case_name):
)
)
.append(pd.DataFrame([node.get_std()], columns=index_list, index=["STD"]))
.append(gen_data)
)
if include_data:
gen_data = facade.load_gen_data(case_name, data_key, time_step).T
data = data.append(gen_data)
return data


def load_block_data(facade, observation_key, case_name):
def load_block_data(facade, observation_key, case_name, include_data=True):
"""
load_block_data is a part of the data_loader_factory, and the other
methods returned by this factory, require case_name, so it is accepted
Expand All @@ -64,8 +65,6 @@ def load_block_data(facade, observation_key, case_name):

data = pd.DataFrame()
for report_step in obs_vector.getStepList().asList():

block_data = loader.load(facade.get_current_fs(), report_step)
obs_block = loader.getBlockObservation(report_step)

data = (
Expand All @@ -77,10 +76,13 @@ def load_block_data(facade, observation_key, case_name):
.append(
pd.DataFrame([[obs_block.getStd(i) for i in obs_block]], index=["STD"])
)
.append(_get_block_measured(facade.get_ensemble_size(), block_data))
)
return data

if include_data:
block_data = loader.load(facade.get_current_fs(), report_step)
data = data.append(_get_block_measured(facade.get_ensemble_size(), block_data))

return data

def _get_block_measured(ensamble_size, block_data):
data = pd.DataFrame()
Expand All @@ -89,13 +91,14 @@ def _get_block_measured(ensamble_size, block_data):
return data


def load_summary_data(facade, observation_key, case_name):
def load_summary_data(facade, observation_key, case_name, include_data=True):
data_key = facade.get_data_key_for_obs_key(observation_key)
args = (facade, observation_key, data_key, case_name)
return pd.concat([
_get_summary_data(*args),
_get_summary_observations(*args).pipe(_remove_inactive_report_steps, *args)
])
data = []
if include_data:
data.append(_get_summary_data(*args))
data.append(_get_summary_observations(*args).pipe(_remove_inactive_report_steps, *args))
return pd.concat(data)


def _get_summary_data(facade, _, data_key, case_name):
Expand Down
1 change: 0 additions & 1 deletion ert_gui/plottery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
except:
pass

from .plot_data_gatherer import PlotDataGatherer
from .plot_style import PlotStyle
from .plot_limits import PlotLimits
from .plot_config import PlotConfig
Expand Down
21 changes: 5 additions & 16 deletions ert_gui/plottery/plot_config_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,23 @@
class PlotConfigFactory(object):

@classmethod
def createPlotConfigForKey(cls, ert, key):
def createPlotConfigForKey(cls, key_def):
"""
@type ert: res.enkf.enkf_main.EnKFMain
@param key: str
@param key_def: dict with definition of a key
@return: PlotConfig
"""
plot_config = PlotConfig(plot_settings=None , title = key)
return PlotConfigFactory.updatePlotConfigForKey(ert, key, plot_config)
plot_config = PlotConfig(plot_settings=None , title = key_def["key"])


@classmethod
def updatePlotConfigForKey(cls, ert, key, plot_config):
"""
@type ert: res.enkf.enkf_main.EnKFMain
@param key: str
@return: PlotConfig
"""
key_manager = ert.getKeyManager()
# The styling of statistics changes based on the nature of the data
if key_manager.isSummaryKey(key) or key_manager.isGenDataKey(key):
if key_def["dimensionality"] == 2:
mean_style = plot_config.getStatisticsStyle("mean")
mean_style.line_style = "-"
plot_config.setStatisticsStyle("mean", mean_style)

p10p90_style = plot_config.getStatisticsStyle("p10-p90")
p10p90_style.line_style = "--"
plot_config.setStatisticsStyle("p10-p90", p10p90_style)
else:
elif key_def["dimensionality"] == 1:
mean_style = plot_config.getStatisticsStyle("mean")
mean_style.line_style = "-"
mean_style.marker = "o"
Expand Down
19 changes: 2 additions & 17 deletions ert_gui/plottery/plot_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .plot_config import PlotConfig
from .plot_data_gatherer import PlotDataGatherer

class PlotContext(object):
UNKNOWN_AXIS = None
Expand All @@ -11,31 +10,21 @@ class PlotContext(object):
DEPTH_AXIS = "DEPTH"
AXIS_TYPES = [UNKNOWN_AXIS, COUNT_AXIS, DATE_AXIS, DENSITY_AXIS, DEPTH_AXIS, INDEX_AXIS, VALUE_AXIS]

def __init__(self, ert, figure, plot_config, cases, key, data_gatherer):
def __init__(self, plot_config, cases, key):
super(PlotContext, self).__init__()
self._data_gatherer = data_gatherer
self._key = key
self._cases = cases
self._figure = figure
self._ert = ert
self._plot_config = plot_config
self.refcase_data = None

self._date_support_active = True
self._x_axis = None
self._y_axis = None

def figure(self):
""" :rtype: matplotlib.figure.Figure"""
return self._figure

def plotConfig(self):
""" :rtype: PlotConfig """
return self._plot_config

def ert(self):
""" :rtype: res.enkf.EnKFMain"""
return self._ert

def cases(self):
""" :rtype: list of str """
return self._cases
Expand All @@ -44,10 +33,6 @@ def key(self):
""" :rtype: str """
return self._key

def dataGatherer(self):
""" :rtype: PlotDataGatherer """
return self._data_gatherer

def deactivateDateSupport(self):
self._date_support_active = False

Expand Down
159 changes: 0 additions & 159 deletions ert_gui/plottery/plot_data_gatherer.py

This file was deleted.

14 changes: 0 additions & 14 deletions ert_gui/plottery/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +0,0 @@
import os
import matplotlib

from .histogram import plotHistogram
from .gaussian_kde import plotGaussianKDE

from .refcase import plotRefcase
from .history import plotHistory
from .observations import plotObservations

from .ensemble import plotEnsemble
from .statistics import plotStatistics
from .distribution import plotDistribution
from .ccsp import plotCrossCaseStatistics
19 changes: 13 additions & 6 deletions ert_gui/plottery/plots/ccsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,20 @@
from .plot_tools import PlotTools
import pandas as pd

def plotCrossCaseStatistics(plot_context):

class CrossCaseStatisticsPlot(object):

def __init__(self):
self.dimensionality = 1

def plot(self, figure, plot_context, case_to_data_map, _observation_data):
plotCrossCaseStatistics(figure,plot_context, case_to_data_map, _observation_data)

def plotCrossCaseStatistics(figure, plot_context, case_to_data_map, _observation_data):
""" @type plot_context: ert_gui.plottery.PlotContext """
ert = plot_context.ert()
key = plot_context.key()
config = plot_context.plotConfig()
axes = plot_context.figure().add_subplot(111)
axes = figure.add_subplot(111)
""":type: matplotlib.axes.Axes """

plot_context.deactivateDateSupport()
Expand All @@ -33,9 +41,8 @@ def plotCrossCaseStatistics(plot_context):
"p67": {},
"p90": {}
}
for case_index, case in enumerate(case_list):
for case_index, (case, data) in enumerate(case_to_data_map.items()):
case_indexes.append(case_index)
data = plot_context.dataGatherer().gatherData(ert, case, key)
std_dev_factor = config.getStandardDeviationFactor()

if not data.empty:
Expand Down Expand Up @@ -68,7 +75,7 @@ def plotCrossCaseStatistics(plot_context):

axes.set_xticklabels([""] + case_list + [""], rotation=rotation)

PlotTools.finalizePlot(plot_context, axes, default_x_label="Case", default_y_label="Value")
PlotTools.finalizePlot(plot_context, figure, axes, default_x_label="Case", default_y_label="Value")


def _addStatisticsLegends(plot_config):
Expand Down
Loading

0 comments on commit a035f98

Please sign in to comment.