diff --git a/ert_data/loader.py b/ert_data/loader.py index b49c9a01b95..525e48b643e 100644 --- a/ert_data/loader.py +++ b/ert_data/loader.py @@ -17,7 +17,7 @@ def data_loader_factory(observation_type): raise TypeError("Unknown observation type: {}".format(observation_type)) -def load_general_data(facade, observation_key, case_name): +def load_general_data(facade, observation_key, case_name, include_data=True): obs_vector = facade.get_observations()[observation_key] data_key = obs_vector.getDataKey() @@ -26,7 +26,6 @@ def load_general_data(facade, observation_key, case_name): for time_step in obs_vector.getStepList().asList(): # Fetch, then transpose the simulation data in order to make it # conform with the GenObservation data structure. - gen_data = facade.load_gen_data(case_name, data_key, time_step).T # Observations and its standard deviation are a subset of the simulation data. # The index_list refers to indices in the simulation data. In order to @@ -48,12 +47,14 @@ def load_general_data(facade, observation_key, case_name): ) ) .append(pd.DataFrame([node.get_std()], columns=index_list, index=["STD"])) - .append(gen_data) ) + if include_data: + gen_data = facade.load_gen_data(case_name, data_key, time_step).T + data = data.append(gen_data) return data -def load_block_data(facade, observation_key, case_name): +def load_block_data(facade, observation_key, case_name, include_data=True): """ load_block_data is a part of the data_loader_factory, and the other methods returned by this factory, require case_name, so it is accepted @@ -64,8 +65,6 @@ def load_block_data(facade, observation_key, case_name): data = pd.DataFrame() for report_step in obs_vector.getStepList().asList(): - - block_data = loader.load(facade.get_current_fs(), report_step) obs_block = loader.getBlockObservation(report_step) data = ( @@ -77,10 +76,13 @@ def load_block_data(facade, observation_key, case_name): .append( pd.DataFrame([[obs_block.getStd(i) for i in obs_block]], index=["STD"]) ) - .append(_get_block_measured(facade.get_ensemble_size(), block_data)) ) - return data + if include_data: + block_data = loader.load(facade.get_current_fs(), report_step) + data = data.append(_get_block_measured(facade.get_ensemble_size(), block_data)) + + return data def _get_block_measured(ensamble_size, block_data): data = pd.DataFrame() @@ -89,13 +91,14 @@ def _get_block_measured(ensamble_size, block_data): return data -def load_summary_data(facade, observation_key, case_name): +def load_summary_data(facade, observation_key, case_name, include_data=True): data_key = facade.get_data_key_for_obs_key(observation_key) args = (facade, observation_key, data_key, case_name) - return pd.concat([ - _get_summary_data(*args), - _get_summary_observations(*args).pipe(_remove_inactive_report_steps, *args) - ]) + data = [] + if include_data: + data.append(_get_summary_data(*args)) + data.append(_get_summary_observations(*args).pipe(_remove_inactive_report_steps, *args)) + return pd.concat(data) def _get_summary_data(facade, _, data_key, case_name): diff --git a/ert_gui/plottery/__init__.py b/ert_gui/plottery/__init__.py index 2203e8dd108..cea902b7e4d 100644 --- a/ert_gui/plottery/__init__.py +++ b/ert_gui/plottery/__init__.py @@ -7,7 +7,6 @@ except: pass -from .plot_data_gatherer import PlotDataGatherer from .plot_style import PlotStyle from .plot_limits import PlotLimits from .plot_config import PlotConfig diff --git a/ert_gui/plottery/plot_config_factory.py b/ert_gui/plottery/plot_config_factory.py index 8a2d6a25469..4b804b4d3ea 100644 --- a/ert_gui/plottery/plot_config_factory.py +++ b/ert_gui/plottery/plot_config_factory.py @@ -4,26 +4,15 @@ class PlotConfigFactory(object): @classmethod - def createPlotConfigForKey(cls, ert, key): + def createPlotConfigForKey(cls, key_def): """ - @type ert: res.enkf.enkf_main.EnKFMain - @param key: str + @param key_def: dict with definition of a key @return: PlotConfig """ - plot_config = PlotConfig(plot_settings=None , title = key) - return PlotConfigFactory.updatePlotConfigForKey(ert, key, plot_config) + plot_config = PlotConfig(plot_settings=None , title = key_def["key"]) - - @classmethod - def updatePlotConfigForKey(cls, ert, key, plot_config): - """ - @type ert: res.enkf.enkf_main.EnKFMain - @param key: str - @return: PlotConfig - """ - key_manager = ert.getKeyManager() # The styling of statistics changes based on the nature of the data - if key_manager.isSummaryKey(key) or key_manager.isGenDataKey(key): + if key_def["dimensionality"] == 2: mean_style = plot_config.getStatisticsStyle("mean") mean_style.line_style = "-" plot_config.setStatisticsStyle("mean", mean_style) @@ -31,7 +20,7 @@ def updatePlotConfigForKey(cls, ert, key, plot_config): p10p90_style = plot_config.getStatisticsStyle("p10-p90") p10p90_style.line_style = "--" plot_config.setStatisticsStyle("p10-p90", p10p90_style) - else: + elif key_def["dimensionality"] == 1: mean_style = plot_config.getStatisticsStyle("mean") mean_style.line_style = "-" mean_style.marker = "o" diff --git a/ert_gui/plottery/plot_context.py b/ert_gui/plottery/plot_context.py index b9c82179a33..81cbc5fe94f 100644 --- a/ert_gui/plottery/plot_context.py +++ b/ert_gui/plottery/plot_context.py @@ -1,5 +1,4 @@ from .plot_config import PlotConfig -from .plot_data_gatherer import PlotDataGatherer class PlotContext(object): UNKNOWN_AXIS = None @@ -11,31 +10,21 @@ class PlotContext(object): DEPTH_AXIS = "DEPTH" AXIS_TYPES = [UNKNOWN_AXIS, COUNT_AXIS, DATE_AXIS, DENSITY_AXIS, DEPTH_AXIS, INDEX_AXIS, VALUE_AXIS] - def __init__(self, ert, figure, plot_config, cases, key, data_gatherer): + def __init__(self, plot_config, cases, key): super(PlotContext, self).__init__() - self._data_gatherer = data_gatherer self._key = key self._cases = cases - self._figure = figure - self._ert = ert self._plot_config = plot_config + self.refcase_data = None self._date_support_active = True self._x_axis = None self._y_axis = None - def figure(self): - """ :rtype: matplotlib.figure.Figure""" - return self._figure - def plotConfig(self): """ :rtype: PlotConfig """ return self._plot_config - def ert(self): - """ :rtype: res.enkf.EnKFMain""" - return self._ert - def cases(self): """ :rtype: list of str """ return self._cases @@ -44,10 +33,6 @@ def key(self): """ :rtype: str """ return self._key - def dataGatherer(self): - """ :rtype: PlotDataGatherer """ - return self._data_gatherer - def deactivateDateSupport(self): self._date_support_active = False diff --git a/ert_gui/plottery/plot_data_gatherer.py b/ert_gui/plottery/plot_data_gatherer.py deleted file mode 100644 index b1e5ff5945c..00000000000 --- a/ert_gui/plottery/plot_data_gatherer.py +++ /dev/null @@ -1,159 +0,0 @@ -from pandas import DataFrame -from res.enkf.export import GenKwCollector, SummaryCollector, GenDataCollector, SummaryObservationCollector, \ - GenDataObservationCollector, CustomKWCollector - - -class PlotDataGatherer(object): - - def __init__(self, dataGatherFunc, conditionFunc, refcaseGatherFunc=None, observationGatherFunc=None, historyGatherFunc=None): - super(PlotDataGatherer, self).__init__() - - self._dataGatherFunction = dataGatherFunc - self._conditionFunction = conditionFunc - self._refcaseGatherFunction = refcaseGatherFunc - self._observationGatherFunction = observationGatherFunc - self._historyGatherFunc = historyGatherFunc - - def hasHistoryGatherFunction(self): - """ :rtype: bool """ - return self._historyGatherFunc is not None - - def hasRefcaseGatherFunction(self): - """ :rtype: bool """ - return self._refcaseGatherFunction is not None - - def hasObservationGatherFunction(self): - """ :rtype: bool """ - return self._observationGatherFunction is not None - - def canGatherDataForKey(self, key): - """ :rtype: bool """ - return self._conditionFunction(key) - - def gatherData(self, ert, case, key): - """ :rtype: pandas.DataFrame """ - if not self.canGatherDataForKey(key): - raise UserWarning("Unable to gather data for key: %s" % key) - - return self._dataGatherFunction(ert, case, key) - - def gatherRefcaseData(self, ert, key): - """ :rtype: pandas.DataFrame """ - if not self.canGatherDataForKey(key) or not self.hasRefcaseGatherFunction(): - raise UserWarning("Unable to gather refcase data for key: %s" % key) - - return self._refcaseGatherFunction(ert, key) - - def gatherObservationData(self, ert, case, key): - """ :rtype: pandas.DataFrame """ - if not self.canGatherDataForKey(key) or not self.hasObservationGatherFunction(): - raise UserWarning("Unable to gather observation data for key: %s" % key) - - return self._observationGatherFunction(ert, case, key) - - def gatherHistoryData(self, ert, case, key): - """ :rtype: pandas.DataFrame """ - if not self.canGatherDataForKey(key) or not self.hasHistoryGatherFunction(): - raise UserWarning("Unable to gather history data for key: %s" % key) - - return self._historyGatherFunc(ert, case, key) - - - @staticmethod - def gatherGenKwData(ert, case, key): - """ :rtype: pandas.DataFrame """ - data = GenKwCollector.loadAllGenKwData(ert, case, [key]) - return data[key].dropna() - - @staticmethod - def gatherSummaryData(ert, case, key): - """ :rtype: pandas.DataFrame """ - data = SummaryCollector.loadAllSummaryData(ert, case, [key]) - if not data.empty: - data = data.reset_index() - - if any(data.duplicated()): - print("** Warning: The simulation data contains duplicate " - "timestamps. A possible explanation is that your " - "simulation timestep is less than a second.") - data = data.drop_duplicates() - - - data = data.pivot(index="Date", columns="Realization", values=key) - - return data #.dropna() - - @staticmethod - def gatherSummaryRefcaseData(ert, key): - refcase = ert.eclConfig().getRefcase() - - if refcase is None or key not in refcase: - return DataFrame() - - values = refcase.numpy_vector(key, report_only=False) - dates = refcase.numpy_dates - - data = DataFrame(zip(dates, values), columns=['Date', key]) - data.set_index("Date", inplace=True) - - return data.iloc[1:] - - @staticmethod - def gatherSummaryHistoryData(ert, case, key): - # create history key - if ":" in key: - head, tail = key.split(":", 2) - key = "%sH:%s" % (head, tail) - else: - key = "%sH" % key - - data = PlotDataGatherer.gatherSummaryRefcaseData(ert, key) - if data.empty and case is not None: - data = PlotDataGatherer.gatherSummaryData(ert, case, key) - - return data - - @staticmethod - def gatherSummaryObservationData(ert, case, key): - if ert.getKeyManager().isKeyWithObservations(key): - return SummaryObservationCollector.loadObservationData(ert, case, [key]).dropna() - else: - return DataFrame() - - - @staticmethod - def gatherGenDataData(ert, case, key): - """ :rtype: pandas.DataFrame """ - key, report_step = key.split("@", 1) - report_step = int(report_step) - try: - data = GenDataCollector.loadGenData(ert, case, key, report_step) - except ValueError: - data = DataFrame() - - return data.dropna() # removes all rows that has a NaN - - - @staticmethod - def gatherGenDataObservationData(ert, case, key_with_report_step): - """ :rtype: pandas.DataFrame """ - key, report_step = key_with_report_step.split("@", 1) - report_step = int(report_step) - - obs_key = GenDataObservationCollector.getObservationKeyForDataKey(ert, key, report_step) - - if obs_key is not None: - obs_data = GenDataObservationCollector.loadGenDataObservations(ert, case, obs_key) - columns = {obs_key: key_with_report_step, "STD_%s" % obs_key: "STD_%s" % key_with_report_step} - obs_data = obs_data.rename(columns=columns) - else: - obs_data = DataFrame() - - return obs_data.dropna() - - @staticmethod - def gatherCustomKwData(ert, case, key): - """ :rtype: pandas.DataFrame """ - data = CustomKWCollector.loadAllCustomKWData(ert, case, [key])[key] - - return data diff --git a/ert_gui/plottery/plots/__init__.py b/ert_gui/plottery/plots/__init__.py index 64f84b1e9ac..e69de29bb2d 100644 --- a/ert_gui/plottery/plots/__init__.py +++ b/ert_gui/plottery/plots/__init__.py @@ -1,14 +0,0 @@ -import os -import matplotlib - -from .histogram import plotHistogram -from .gaussian_kde import plotGaussianKDE - -from .refcase import plotRefcase -from .history import plotHistory -from .observations import plotObservations - -from .ensemble import plotEnsemble -from .statistics import plotStatistics -from .distribution import plotDistribution -from .ccsp import plotCrossCaseStatistics diff --git a/ert_gui/plottery/plots/ccsp.py b/ert_gui/plottery/plots/ccsp.py index 52679222d79..7b81932eaca 100644 --- a/ert_gui/plottery/plots/ccsp.py +++ b/ert_gui/plottery/plots/ccsp.py @@ -3,12 +3,20 @@ from .plot_tools import PlotTools import pandas as pd -def plotCrossCaseStatistics(plot_context): + +class CrossCaseStatisticsPlot(object): + + def __init__(self): + self.dimensionality = 1 + + def plot(self, figure, plot_context, case_to_data_map, _observation_data): + plotCrossCaseStatistics(figure,plot_context, case_to_data_map, _observation_data) + +def plotCrossCaseStatistics(figure, plot_context, case_to_data_map, _observation_data): """ @type plot_context: ert_gui.plottery.PlotContext """ - ert = plot_context.ert() key = plot_context.key() config = plot_context.plotConfig() - axes = plot_context.figure().add_subplot(111) + axes = figure.add_subplot(111) """:type: matplotlib.axes.Axes """ plot_context.deactivateDateSupport() @@ -33,9 +41,8 @@ def plotCrossCaseStatistics(plot_context): "p67": {}, "p90": {} } - for case_index, case in enumerate(case_list): + for case_index, (case, data) in enumerate(case_to_data_map.items()): case_indexes.append(case_index) - data = plot_context.dataGatherer().gatherData(ert, case, key) std_dev_factor = config.getStandardDeviationFactor() if not data.empty: @@ -68,7 +75,7 @@ def plotCrossCaseStatistics(plot_context): axes.set_xticklabels([""] + case_list + [""], rotation=rotation) - PlotTools.finalizePlot(plot_context, axes, default_x_label="Case", default_y_label="Value") + PlotTools.finalizePlot(plot_context, figure, axes, default_x_label="Case", default_y_label="Value") def _addStatisticsLegends(plot_config): diff --git a/ert_gui/plottery/plots/distribution.py b/ert_gui/plottery/plots/distribution.py index af0b546a1d3..fbaf55efc90 100644 --- a/ert_gui/plottery/plots/distribution.py +++ b/ert_gui/plottery/plots/distribution.py @@ -1,12 +1,20 @@ from .plot_tools import PlotTools import pandas as pd -def plotDistribution(plot_context): + +class DistributionPlot(object): + + def __init__(self): + self.dimensionality = 1 + + def plot(self, figure, plot_context, case_to_data_map, _observation_data): + plotDistribution(figure, plot_context, case_to_data_map, _observation_data) + +def plotDistribution(figure, plot_context, case_to_data_map, _observation_data): """ @type plot_context: ert_gui.plottery.PlotContext """ - ert = plot_context.ert() key = plot_context.key() config = plot_context.plotConfig() - axes = plot_context.figure().add_subplot(111) + axes = figure.add_subplot(111) """:type: matplotlib.axes.Axes """ plot_context.deactivateDateSupport() @@ -20,9 +28,8 @@ def plotDistribution(plot_context): case_list = plot_context.cases() case_indexes = [] previous_data = None - for case_index, case in enumerate(case_list): + for case_index, (case, data) in enumerate(case_to_data_map.items()): case_indexes.append(case_index) - data = plot_context.dataGatherer().gatherData(ert, case, key) if not data.empty: _plotDistribution(axes, config, data, case, case_index, previous_data) @@ -40,7 +47,7 @@ def plotDistribution(plot_context): config.setLegendEnabled(False) - PlotTools.finalizePlot(plot_context, axes, default_x_label="Case", default_y_label="Value") + PlotTools.finalizePlot(plot_context, figure, axes, default_x_label="Case", default_y_label="Value") def _plotDistribution(axes, plot_config, data, label, index, previous_data): diff --git a/ert_gui/plottery/plots/ensemble.py b/ert_gui/plottery/plots/ensemble.py index 4afd1105c11..1d82ca1f628 100644 --- a/ert_gui/plottery/plots/ensemble.py +++ b/ert_gui/plottery/plots/ensemble.py @@ -1,59 +1,80 @@ -from .refcase import plotRefcase -from .history import plotHistory from .observations import plotObservations from .plot_tools import PlotTools -def plotEnsemble(plot_context): - """ - @type plot_context: ert_gui.plottery.PlotContext - """ - ert = plot_context.ert() - key = plot_context.key() - config = plot_context.plotConfig() - """:type: ert_gui.plottery.PlotConfig """ - axes = plot_context.figure().add_subplot(111) - """:type: matplotlib.axes.Axes """ - case_list = plot_context.cases() +class EnsemblePlot(object): - plot_context.y_axis = plot_context.VALUE_AXIS - plot_context.x_axis = plot_context.DATE_AXIS + def __init__(self): + self.dimensionality = 2 - for case in case_list: - data = plot_context.dataGatherer().gatherData(ert, case, key) - if not data.empty: - if not data.index.is_all_dates: - plot_context.deactivateDateSupport() - plot_context.x_axis = plot_context.INDEX_AXIS + def plot(self, figure, plot_context, case_to_data_map, observation_data): + """ + @type plot_context: ert_gui.plottery.PlotContext + """ + config = plot_context.plotConfig() + """:type: ert_gui.plottery.PlotConfig """ + axes = figure.add_subplot(111) + """:type: matplotlib.axes.Axes """ - _plotLines(axes, config, data, case, plot_context.isDateSupportActive()) - config.nextColor() + case_list = plot_context.cases() - plotRefcase(plot_context, axes) - plotObservations(plot_context, axes) - plotHistory(plot_context, axes) + plot_context.y_axis = plot_context.VALUE_AXIS + plot_context.x_axis = plot_context.DATE_AXIS - default_x_label = "Date" if plot_context.isDateSupportActive() else "Index" - PlotTools.finalizePlot(plot_context, axes, default_x_label=default_x_label, default_y_label="Value") + for case, data in case_to_data_map.items(): + data = data.T + if not data.empty: + if not data.columns.is_all_dates: + plot_context.deactivateDateSupport() + plot_context.x_axis = plot_context.INDEX_AXIS -def _plotLines(axes, plot_config, data, ensemble_label, is_date_supported): - """ - @type axes: matplotlib.axes.Axes - @type plot_config: ert_gui.plottery.PlotConfig - @type data: pandas.DataFrame - @type ensemble_label: Str - """ + self._plotLines(axes, config, data, case, plot_context.isDateSupportActive()) + config.nextColor() - style = plot_config.defaultStyle() + self.plotRefcase(plot_context, axes) + plotObservations(observation_data, plot_context, axes) - if len(data) == 1 and style.marker == '': - style.marker = '.' + default_x_label = "Date" if plot_context.isDateSupportActive() else "Index" + PlotTools.finalizePlot(plot_context, figure, axes, default_x_label=default_x_label, default_y_label="Value") - if is_date_supported: - lines = axes.plot_date(x=data.index.values, y=data, color=style.color, alpha=style.alpha, marker=style.marker, linestyle=style.line_style, linewidth=style.width, markersize=style.size) - else: - lines = axes.plot(data.index.values, data, color=style.color, alpha=style.alpha, marker=style.marker, linestyle=style.line_style, linewidth=style.width, markersize=style.size) - if len(lines) > 0: - plot_config.addLegendItem(ensemble_label, lines[0]) + def _plotLines(self, axes, plot_config, data, ensemble_label, is_date_supported): + """ + @type axes: matplotlib.axes.Axes + @type plot_config: ert_gui.plottery.PlotConfig + @type data: pandas.DataFrame + @type ensemble_label: Str + """ + + style = plot_config.defaultStyle() + + if len(data) == 1 and style.marker == '': + style.marker = '.' + + if is_date_supported: + lines = axes.plot_date(x=data.index.values, y=data, color=style.color, alpha=style.alpha, marker=style.marker, linestyle=style.line_style, linewidth=style.width, markersize=style.size) + else: + lines = axes.plot(data.index.values, data, color=style.color, alpha=style.alpha, marker=style.marker, linestyle=style.line_style, linewidth=style.width, markersize=style.size) + + if len(lines) > 0: + plot_config.addLegendItem(ensemble_label, lines[0]) + + def plotRefcase(self, plot_context, axes): + plot_config = plot_context.plotConfig() + + if (not plot_config.isRefcaseEnabled() + or plot_context.refcase_data is None + or plot_context.refcase_data.empty): + return + + data = plot_context.refcase_data + style = plot_config.refcaseStyle() + + lines = axes.plot_date(x=data.index.values, y=data, color=style.color, alpha=style.alpha, + marker=style.marker, linestyle=style.line_style, linewidth=style.width, + markersize=style.size) + + if len(lines) > 0 and style.isVisible(): + plot_config.addLegendItem("Refcase", lines[0]) + diff --git a/ert_gui/plottery/plots/gaussian_kde.py b/ert_gui/plottery/plots/gaussian_kde.py index 5e6c6b71242..42825edaed3 100644 --- a/ert_gui/plottery/plots/gaussian_kde.py +++ b/ert_gui/plottery/plots/gaussian_kde.py @@ -1,17 +1,23 @@ import numpy from scipy.stats import gaussian_kde from .plot_tools import PlotTools -import pandas as pd -def plotGaussianKDE(plot_context): +class GaussianKDEPlot(object): + def __init__(self): + self.dimensionality = 1 + + def plot(self, figure, plot_context, case_to_data_map, _observation_data): + plotGaussianKDE(figure, plot_context, case_to_data_map, _observation_data) + + +def plotGaussianKDE(figure, plot_context, case_to_data_map, _observation_data): """ @type plot_context: ert_gui.plottery.PlotContext """ - ert = plot_context.ert() key = plot_context.key() config = plot_context.plotConfig() - axes = plot_context.figure().add_subplot(111) + axes = figure.add_subplot(111) """:type: matplotlib.axes.Axes """ plot_context.deactivateDateSupport() @@ -22,15 +28,12 @@ def plotGaussianKDE(plot_context): key = key[6:] axes.set_xscale("log") - case_list = plot_context.cases() - for case in case_list: - data = plot_context.dataGatherer().gatherData(ert, case, key) - + for case, data in case_to_data_map.items(): if not data.empty and data.nunique() > 1: _plotGaussianKDE(axes, config, data, case) config.nextColor() - PlotTools.finalizePlot(plot_context, axes, default_x_label="Value", default_y_label="Density") + PlotTools.finalizePlot(plot_context, figure, axes, default_x_label="Value", default_y_label="Density") def _plotGaussianKDE(axes, plot_config, data, label): @@ -43,21 +46,12 @@ def _plotGaussianKDE(axes, plot_config, data, label): style = plot_config.histogramStyle() - if data.dtype == "object": - try: - data = pd.to_numeric(data, errors='coerce') - except AttributeError: - data = data.convert_objects(convert_numeric=True) - - if data.dtype == "object": - pass - else: - sample_range = data.max() - data.min() - indexes = numpy.linspace(data.min() - 0.5 * sample_range, data.max() + 0.5 * sample_range, 1000) - gkde = gaussian_kde(data.values) - evaluated_gkde = gkde.evaluate(indexes) + sample_range = data.max() - data.min() + indexes = numpy.linspace(data.min() - 0.5 * sample_range, data.max() + 0.5 * sample_range, 1000) + gkde = gaussian_kde(data.values) + evaluated_gkde = gkde.evaluate(indexes) - lines = axes.plot(indexes, evaluated_gkde, linewidth=style.width, color=style.color, alpha=style.alpha) + lines = axes.plot(indexes, evaluated_gkde, linewidth=style.width, color=style.color, alpha=style.alpha) - if len(lines) > 0: - plot_config.addLegendItem(label, lines[0]) \ No newline at end of file + if len(lines) > 0: + plot_config.addLegendItem(label, lines[0]) diff --git a/ert_gui/plottery/plots/histogram.py b/ert_gui/plottery/plots/histogram.py index eeb308c10ae..533b45384a4 100644 --- a/ert_gui/plottery/plots/histogram.py +++ b/ert_gui/plottery/plots/histogram.py @@ -4,9 +4,17 @@ from .plot_tools import PlotTools import pandas as pd -def plotHistogram(plot_context): + +class HistogramPlot(object): + + def __init__(self): + self.dimensionality = 1 + + def plot(self, figure, plot_context, case_to_data_map, _observation_data): + plotHistogram(figure, plot_context, case_to_data_map, _observation_data) + +def plotHistogram(figure, plot_context, case_to_data_map, _observation_data): """ @type plot_context: ert_gui.plottery.PlotContext """ - ert = plot_context.ert() key = plot_context.key() config = plot_context.plotConfig() @@ -33,8 +41,8 @@ def plotHistogram(plot_context): categories = set() max_element_count = 0 categorical = False - for case in case_list: - data[case] = plot_context.dataGatherer().gatherData(ert, case, key) + for case, datas in case_to_data_map.items(): + data[case] = datas if data[case].dtype == "object": try: @@ -48,15 +56,17 @@ def plotHistogram(plot_context): if categorical: categories = categories.union(set(data[case].unique())) else: + current_min = data[case].min() + current_max = data[case].max() if minimum is None: - minimum = data[case].min() + minimum = current_min else: - minimum = min(minimum, data[case].min()) + minimum = min(minimum, current_min) if maximum is None: - maximum = data[case].max() + maximum = current_max else: - maximum = max(maximum, data[case].max()) + maximum = max(maximum, current_max) max_element_count = max(max_element_count, len(data[case].index)) @@ -66,7 +76,7 @@ def plotHistogram(plot_context): axes = {} """:type: dict of (str, matplotlib.axes.Axes) """ for index, case in enumerate(case_list): - axes[case] = plot_context.figure().add_subplot(case_count, 1, index + 1) + axes[case] = figure.add_subplot(case_count, 1, index + 1) axes[case].set_title("%s (%s)" % (config.title(), case)) diff --git a/ert_gui/plottery/plots/observations.py b/ert_gui/plottery/plots/observations.py index 68d050b68ad..413f092a365 100644 --- a/ert_gui/plottery/plots/observations.py +++ b/ert_gui/plottery/plots/observations.py @@ -1,19 +1,15 @@ import math -def plotObservations(plot_context, axes): - ert = plot_context.ert() +def plotObservations(observation_data, plot_context, axes): key = plot_context.key() config = plot_context.plotConfig() case_list = plot_context.cases() - data_gatherer = plot_context.dataGatherer() - - if config.isObservationsEnabled() and data_gatherer.hasObservationGatherFunction(): - if len(case_list) > 0: - observation_data = data_gatherer.gatherObservationData(ert, case_list[0], key) - - if not observation_data.empty: - _plotObservations(axes, config, observation_data, value_column=key) + if (config.isObservationsEnabled() + and len(case_list) > 0 + and observation_data is not None + and not observation_data.empty): + _plotObservations(axes, config, observation_data, value_column=key) def _plotObservations(axes, plot_config, data, value_column): @@ -39,8 +35,8 @@ def cap_size(line_with): if style.line_style == '': style.width = 0 - errorbars = axes.errorbar(x=data.index.values, y=data[value_column].values, - yerr=data["STD_%s" % value_column].values, + errorbars = axes.errorbar(x=data.columns.get_level_values("key_index").values, y=data.loc["OBS"].values, + yerr=data.loc["STD"].values, fmt=style.line_style, ecolor=style.color, color=style.color, capsize=cap_size(style.width), capthick=style.width, #same as width/thickness on error line diff --git a/ert_gui/plottery/plots/plot_tools.py b/ert_gui/plottery/plots/plot_tools.py index 7e2a3b93a29..79a6636918e 100644 --- a/ert_gui/plottery/plots/plot_tools.py +++ b/ert_gui/plottery/plots/plot_tools.py @@ -60,7 +60,7 @@ def _getYAxisLimits(plot_context): @staticmethod - def finalizePlot(plot_context, axes, default_x_label="Unnamed", default_y_label="Unnamed"): + def finalizePlot(plot_context, figure, axes, default_x_label="Unnamed", default_y_label="Unnamed"): """ @type plot_context: ert_gui.plottery.PlotContext @type axes: @@ -88,12 +88,12 @@ def finalizePlot(plot_context, axes, default_x_label="Unnamed", default_y_label= axes.set_title(plot_config.title()) if plot_context.isDateSupportActive(): - plot_context.figure().autofmt_xdate() + figure.autofmt_xdate() @staticmethod def __setupLabels(plot_context, default_x_label, default_y_label): - ert = plot_context.ert() + #ert = plot_context.ert() key = plot_context.key() config = plot_context.plotConfig() @@ -103,7 +103,7 @@ def __setupLabels(plot_context, default_x_label, default_y_label): if config.yLabel() is None: config.setYLabel(default_y_label) - if ert.eclConfig().hasRefcase() and key in ert.eclConfig().getRefcase(): - unit = ert.eclConfig().getRefcase().unit(key) - if unit != "": - config.setYLabel(unit) \ No newline at end of file + #if ert.eclConfig().hasRefcase() and key in ert.eclConfig().getRefcase(): + #unit = ert.eclConfig().getRefcase().unit(key) + #if unit != "": + #config.setYLabel(unit) \ No newline at end of file diff --git a/ert_gui/plottery/plots/refcase.py b/ert_gui/plottery/plots/refcase.py deleted file mode 100644 index 4a0836bf1b6..00000000000 --- a/ert_gui/plottery/plots/refcase.py +++ /dev/null @@ -1,26 +0,0 @@ -def plotRefcase(plot_context, axes): - ert = plot_context.ert() - key = plot_context.key() - config = plot_context.plotConfig() - data_gatherer = plot_context.dataGatherer() - - if config.isRefcaseEnabled() and data_gatherer.hasRefcaseGatherFunction(): - refcase_data = data_gatherer.gatherRefcaseData(ert, key) - - if not refcase_data.empty: - _plotRefcase(axes, config, refcase_data) - - -def _plotRefcase(axes, plot_config, data): - """ - @type axes: matplotlib.axes.Axes - @type plot_config: PlotConfig - @type data: DataFrame - """ - - style = plot_config.refcaseStyle() - - lines = axes.plot_date(x=data.index.values, y=data, color=style.color, alpha=style.alpha, marker=style.marker, linestyle=style.line_style, linewidth=style.width, markersize=style.size) - - if len(lines) > 0 and style.isVisible(): - plot_config.addLegendItem("Refcase", lines[0]) \ No newline at end of file diff --git a/ert_gui/plottery/plots/statistics.py b/ert_gui/plottery/plots/statistics.py index 79b6f990670..034450cb0f3 100644 --- a/ert_gui/plottery/plots/statistics.py +++ b/ert_gui/plottery/plots/statistics.py @@ -1,61 +1,64 @@ from matplotlib.patches import Rectangle from matplotlib.lines import Line2D from pandas import DataFrame -from .refcase import plotRefcase + from .observations import plotObservations from .plot_tools import PlotTools -def plotStatistics(plot_context): - """ @type plot_context: ert_gui.plottery.PlotContext """ - ert = plot_context.ert() - key = plot_context.key() - config = plot_context.plotConfig() - """:type: ert_gui.plotter.PlotConfig """ - axes = plot_context.figure().add_subplot(111) - """:type: matplotlib.axes.Axes """ - - plot_context.y_axis = plot_context.VALUE_AXIS - plot_context.x_axis = plot_context.DATE_AXIS - - case_list = plot_context.cases() - for case in case_list: - data = plot_context.dataGatherer().gatherData(ert, case, key) - - if not data.empty: - if not data.index.is_all_dates: - plot_context.deactivateDateSupport() - plot_context.x_axis = plot_context.INDEX_AXIS - - style = config.getStatisticsStyle("mean") - rectangle = Rectangle((0, 0), 1, 1, color=style.color, alpha=0.8) # creates rectangle patch for legend use. - config.addLegendItem(case, rectangle) - - statistics_data = DataFrame() - std_dev_factor = config.getStandardDeviationFactor() - - statistics_data["Minimum"] = data.min(axis=1) - statistics_data["Maximum"] = data.max(axis=1) - statistics_data["Mean"] = data.mean(axis=1) - statistics_data["p10"] = data.quantile(0.1, axis=1) - statistics_data["p33"] = data.quantile(0.33, axis=1) - statistics_data["p50"] = data.quantile(0.50, axis=1) - statistics_data["p67"] = data.quantile(0.67, axis=1) - statistics_data["p90"] = data.quantile(0.90, axis=1) - std = data.std(axis=1) * std_dev_factor - statistics_data["std+"] = statistics_data["Mean"] + std - statistics_data["std-"] = statistics_data["Mean"] - std - - _plotPercentiles(axes, config, statistics_data, case) - config.nextColor() - - _addStatisticsLegends(plot_config=config) - - plotRefcase(plot_context, axes) - plotObservations(plot_context, axes) - - default_x_label = "Date" if plot_context.isDateSupportActive() else "Index" - PlotTools.finalizePlot(plot_context, axes, default_x_label=default_x_label, default_y_label="Value") +class StatisticsPlot(object): + + def __init__(self): + self.dimensionality = 2 + + def plot(self, figure, plot_context, case_to_data_map, _observation_data): + """ @type plot_context: ert_gui.plottery.PlotContext """ + key = plot_context.key() + config = plot_context.plotConfig() + """:type: ert_gui.plotter.PlotConfig """ + axes = figure.add_subplot(111) + """:type: matplotlib.axes.Axes """ + + plot_context.y_axis = plot_context.VALUE_AXIS + plot_context.x_axis = plot_context.DATE_AXIS + + for case, data in case_to_data_map.items(): + data = data.T + if not data.empty: + if not data.index.is_all_dates: + plot_context.deactivateDateSupport() + plot_context.x_axis = plot_context.INDEX_AXIS + + + style = config.getStatisticsStyle("mean") + rectangle = Rectangle((0, 0), 1, 1, color=style.color, alpha=0.8) # creates rectangle patch for legend use. + config.addLegendItem(case, rectangle) + + statistics_data = DataFrame() + std_dev_factor = config.getStandardDeviationFactor() + + statistics_data["Minimum"] = data.min(axis=1) + statistics_data["Maximum"] = data.max(axis=1) + statistics_data["Mean"] = data.mean(axis=1) + statistics_data["p10"] = data.quantile(0.1, axis=1) + statistics_data["p33"] = data.quantile(0.33, axis=1) + statistics_data["p50"] = data.quantile(0.50, axis=1) + statistics_data["p67"] = data.quantile(0.67, axis=1) + statistics_data["p90"] = data.quantile(0.90, axis=1) + std = data.std(axis=1) * std_dev_factor + statistics_data["std+"] = statistics_data["Mean"] + std + statistics_data["std-"] = statistics_data["Mean"] - std + + _plotPercentiles(axes, config, statistics_data, case) + config.nextColor() + + _addStatisticsLegends(plot_config=config) + + #plotRefcase(plot_context, axes) + plotObservations(_observation_data, plot_context, axes) + + default_x_label = "Date" if plot_context.isDateSupportActive() else "Index" + PlotTools.finalizePlot(plot_context, figure, axes, default_x_label=default_x_label, default_y_label="Value") def _addStatisticsLegends(plot_config): _addStatisticsLegend(plot_config, "mean") diff --git a/ert_gui/simulation/run_dialog.py b/ert_gui/simulation/run_dialog.py index ba702794250..1a098e942d1 100644 --- a/ert_gui/simulation/run_dialog.py +++ b/ert_gui/simulation/run_dialog.py @@ -74,7 +74,7 @@ def __init__(self, config_file, run_model, arguments, parent=None): self.running_time = QLabel("") self.plot_tool = PlotTool(config_file) - self.plot_tool.setParent(None) + self.plot_tool.setParent(self) self.plot_button = QPushButton(self.plot_tool.getName()) self.plot_button.clicked.connect(self.plot_tool.trigger) self.plot_button.setEnabled(ert is not None) diff --git a/ert_gui/tools/plot/customize/customize_plot_dialog.py b/ert_gui/tools/plot/customize/customize_plot_dialog.py index 382c495e178..0daff38adb8 100644 --- a/ert_gui/tools/plot/customize/customize_plot_dialog.py +++ b/ert_gui/tools/plot/customize/customize_plot_dialog.py @@ -4,7 +4,7 @@ from ert_shared import ERT from ert_gui.tools.plot.widgets import CopyStyleToDialog from ert_gui.ertwidgets import resourceIcon -from ert_gui.plottery import PlotConfig, PlotConfigHistory +from ert_gui.plottery import PlotConfig, PlotConfigHistory, PlotConfigFactory from ert_gui.tools.plot.customize import DefaultCustomizationView, StyleCustomizationView, \ StatisticsCustomizationView, LimitsCustomizationView @@ -13,21 +13,19 @@ class PlotCustomizer(QObject): settingsChanged = Signal() - def __init__(self, parent, default_plot_settings=None): + def __init__(self, parent, key_defs): super(PlotCustomizer, self).__init__() self._plot_config_key = None self._previous_key = None - self.default_plot_settings = default_plot_settings + self.default_plot_settings = None self._plot_configs = { None: PlotConfigHistory( "No_Key_Selected", - PlotConfig(plot_settings=default_plot_settings, title=None)) + PlotConfig(plot_settings=None, title=None)) } - self._plotConfigCreator = self._defaultPlotConfigCreator - - self._customization_dialog = CustomizePlotDialog("Customize", parent, key=self._plot_config_key) + self._customization_dialog = CustomizePlotDialog("Customize", parent, key_defs, key=self._plot_config_key) self._customization_dialog.addTab("general", "General", DefaultCustomizationView()) self._customization_dialog.addTab("style", "Style", StyleCustomizationView()) @@ -130,16 +128,11 @@ def toggleCustomizationDialog(self): else: self._customization_dialog.show() - def _defaultPlotConfigCreator(self, title): - return PlotConfig(title) - - def _selectiveCopyOfCurrentPlotConfig(self, title): - return self._plotConfigCreator(title) - - def switchPlotConfigHistory(self, key): + def switchPlotConfigHistory(self, key_def): + key = key_def["key"] if key != self._plot_config_key: if not key in self._plot_configs: - self._plot_configs[key] = PlotConfigHistory(key, self._selectiveCopyOfCurrentPlotConfig(key)) + self._plot_configs[key] = PlotConfigHistory(key, PlotConfigFactory.createPlotConfigForKey(key_def)) self._customization_dialog.addCopyableKey(key) self._customization_dialog.currentPlotKeyChanged(key) self._previous_key = self._plot_config_key @@ -153,9 +146,6 @@ def getPlotConfig(self): def setAxisTypes(self, x_axis_type, y_axis_type): self._customize_limits.setAxisTypes(x_axis_type, y_axis_type) - def setPlotConfigCreator(self, func): - self._plotConfigCreator = func - class CustomizePlotDialog(QDialog): applySettings = Signal() @@ -165,18 +155,12 @@ class CustomizePlotDialog(QDialog): copySettings = Signal(str) copySettingsToOthers = Signal(list) - def __init__(self, title, parent=None, key=''): + def __init__(self, title, parent, key_defs, key=''): QDialog.__init__(self, parent) self.setWindowTitle(title) - self._ert = ERT.ert - - """:type: res.enkf.enkf_main.EnKFMain""" - - self.key_manager = self._ert.getKeyManager() - """:type: res.enkf.key_manager.KeyManager """ - self.current_key = key + self._key_defs = key_defs self.setWindowFlags(self.windowFlags() & ~Qt.WindowContextHelpButtonHint) self.setWindowFlags(self.windowFlags() & ~Qt.WindowCloseButtonHint) @@ -255,8 +239,7 @@ def __init__(self, title, parent=None, key=''): self.setLayout(layout) def initiateCopyStyleToDialog(self): - all_other_keys = [k for k in self.key_manager.allDataTypeKeys() if k != self.current_key] - dialog = CopyStyleToDialog(self, self.current_key, all_other_keys) + dialog = CopyStyleToDialog(self, self.current_key, self._key_defs) if dialog.exec_(): self.copySettingsToOthers.emit(dialog.getSelectedKeys()) diff --git a/ert_gui/tools/plot/data_type_keys_list_model.py b/ert_gui/tools/plot/data_type_keys_list_model.py index 5974681abcd..f84a27f19cd 100644 --- a/ert_gui/tools/plot/data_type_keys_list_model.py +++ b/ert_gui/tools/plot/data_type_keys_list_model.py @@ -9,17 +9,14 @@ class DataTypeKeysListModel(QAbstractItemModel): HAS_OBSERVATIONS = QColor(237, 218, 116) GROUP_ITEM = QColor(64, 64, 64) - def __init__(self, ert): + def __init__(self, keys): """ @type ert: res.enkf.EnKFMain """ QAbstractItemModel.__init__(self) - self.__ert = ert + self._keys = keys self.__icon = resourceIcon("ide/small/bullet_star") - def keyManager(self): - return self.__ert.getKeyManager() - def index(self, row, column, parent=None, *args, **kwargs): return self.createIndex(row, column) @@ -27,7 +24,7 @@ def parent(self, index=None): return QModelIndex() def rowCount(self, parent=None, *args, **kwargs): - return len(self.keyManager().allDataTypeKeys()) + return len(self._keys) def columnCount(self, QModelIndex_parent=None, *args, **kwargs): return 1 @@ -36,14 +33,14 @@ def data(self, index, role=None): assert isinstance(index, QModelIndex) if index.isValid(): - items = self.keyManager().allDataTypeKeys() + items = self._keys row = index.row() item = items[row] if role == Qt.DisplayRole: - return item + return item["key"] elif role == Qt.BackgroundRole: - if self.keyManager().isKeyWithObservations(item): + if len(item["observations"]) > 0: return self.HAS_OBSERVATIONS def itemAt(self, index): @@ -51,25 +48,6 @@ def itemAt(self, index): if index.isValid(): row = index.row() - return self.keyManager().allDataTypeKeys()[row] + return self._keys[row] return None - - - def isSummaryKey(self, key): - return self.keyManager().isSummaryKey(key) - - def isBlockKey(self, key): - return False - - def isGenKWKey(self, key): - return self.keyManager().isGenKwKey(key) - - def isGenDataKey(self, key): - return self.keyManager().isGenDataKey(key) - - def isCustomKwKey(self, key): - return self.keyManager().isCustomKwKey(key) - - def isCustomPcaKey(self, key): - return False diff --git a/ert_gui/tools/plot/data_type_keys_widget.py b/ert_gui/tools/plot/data_type_keys_widget.py index 63c77098603..c47d04028ba 100644 --- a/ert_gui/tools/plot/data_type_keys_widget.py +++ b/ert_gui/tools/plot/data_type_keys_widget.py @@ -8,16 +8,16 @@ class DataTypeKeysWidget(QWidget): dataTypeKeySelected = Signal() - def __init__(self, model): + def __init__(self, key_defs): QWidget.__init__(self) - self.__filter_popup = FilterPopup(self) + self.__filter_popup = FilterPopup(self, key_defs) self.__filter_popup.filterSettingsChanged.connect(self.onItemChanged) layout = QVBoxLayout() - self.model = model - self.filter_model = DataTypeProxyModel(self.model) + self.model = DataTypeKeysListModel(key_defs) + self.filter_model = DataTypeProxyModel(self, self.model) filter_layout = QHBoxLayout() @@ -46,12 +46,8 @@ def __init__(self, model): def onItemChanged(self, item): # self.filter_model.setShowBlockKeys(item["block"]) - self.filter_model.setShowSummaryKeys(item["summary"]) - self.filter_model.setShowGenKWKeys(item["gen_kw"]) - self.filter_model.setShowGenDataKeys(item["gen_data"]) - self.filter_model.setShowCustomKwKeys(item["custom_kw"]) - # self.filter_model.setShowCustomPcaKeys(item["custom_pca"]) - + for value, visible in item.items(): + self.filter_model.setFilterOnMetadata("data_origin", value, visible) def itemSelected(self): selected_item = self.getSelectedItem() diff --git a/ert_gui/tools/plot/data_type_proxy_model.py b/ert_gui/tools/plot/data_type_proxy_model.py index 751aa72d2f3..59bf8af2c20 100644 --- a/ert_gui/tools/plot/data_type_proxy_model.py +++ b/ert_gui/tools/plot/data_type_proxy_model.py @@ -21,15 +21,16 @@ class DataTypeProxyModel(QSortFilterProxyModel): - def __init__(self, model , parent=None): + def __init__(self, parent, model): QSortFilterProxyModel.__init__(self, parent) + self.__show_summary_keys = True self.__show_block_keys = True self.__show_gen_kw_keys = True self.__show_gen_data_keys = True self.__show_custom_kw_keys = True self.__show_custom_pca_keys = True - + self._metadata_filters = {} self.setFilterCaseSensitivity(Qt.CaseInsensitive) self.setSourceModel(model) @@ -41,24 +42,10 @@ def filterAcceptsRow(self, index, q_model_index): source_index = source_model.index(index, 0, q_model_index) key = source_model.itemAt(source_index) - if not self.__show_summary_keys and source_model.isSummaryKey(key): - show = False - - elif not self.__show_block_keys and source_model.isBlockKey(key): - show = False - - elif not self.__show_gen_kw_keys and source_model.isGenKWKey(key): - show = False - - elif not self.__show_gen_data_keys and source_model.isGenDataKey(key): - show = False - - elif not self.__show_custom_kw_keys and source_model.isCustomKwKey(key): - show = False - - elif not self.__show_custom_pca_keys and source_model.isCustomPcaKey(key): - show = False - + for meta_key, values in self._metadata_filters.items(): + for value, visible in values.items(): + if not visible and meta_key in key["metadata"] and key["metadata"][meta_key] == value: + show = False return show @@ -66,27 +53,11 @@ def sourceModel(self): """ @rtype: DataTypeKeysListModel """ return QSortFilterProxyModel.sourceModel(self) - def setShowSummaryKeys(self, visible): - self.__show_summary_keys = visible - self.invalidateFilter() - - def setShowBlockKeys(self, visible): - self.__show_block_keys = visible - self.invalidateFilter() - - def setShowGenKWKeys(self, visible): - self.__show_gen_kw_keys = visible - self.invalidateFilter() - - def setShowGenDataKeys(self, visible): - self.__show_gen_data_keys = visible - self.invalidateFilter() + def setFilterOnMetadata(self, key, value, visible): + if not key in self._metadata_filters: + self._metadata_filters[key] = {} - def setShowCustomKwKeys(self, visible): - self.__show_custom_kw_keys = visible + self._metadata_filters[key][value] = visible self.invalidateFilter() - def setShowCustomPcaKeys(self, visible): - self.__show_custom_pca_keys = visible - self.invalidateFilter() diff --git a/ert_gui/tools/plot/filter_popup.py b/ert_gui/tools/plot/filter_popup.py index e2f6602cdb3..77359b7fb87 100644 --- a/ert_gui/tools/plot/filter_popup.py +++ b/ert_gui/tools/plot/filter_popup.py @@ -5,7 +5,7 @@ class FilterPopup(QDialog): filterSettingsChanged = Signal(dict) - def __init__(self, parent=None): + def __init__(self, parent, key_defs): QDialog.__init__(self, parent, Qt.WindowStaysOnTopHint | Qt.X11BypassWindowManagerHint | Qt.FramelessWindowHint) self.setVisible(False) @@ -21,12 +21,10 @@ def __init__(self, parent=None): self.__layout.setSizeConstraint(QLayout.SetFixedSize) self.__layout.addWidget(QLabel("Filter by data type:")) - self.addFilterItem("Summary", "summary") - # self.addFilterItem("Block", "block") - self.addFilterItem("Gen KW", "gen_kw") - self.addFilterItem("Gen Data", "gen_data") - self.addFilterItem("Custom KW", "custom_kw") - # self.addFilterItem("Custom PCA", "custom_pca") + + filters = {k["metadata"]["data_origin"] for k in key_defs} + for f in filters: + self.addFilterItem(f, f) frame.setLayout(self.__layout) diff --git a/ert_gui/tools/plot/filterable_kw_list_model.py b/ert_gui/tools/plot/filterable_kw_list_model.py index 0d0f83f3abe..957939defdd 100644 --- a/ert_gui/tools/plot/filterable_kw_list_model.py +++ b/ert_gui/tools/plot/filterable_kw_list_model.py @@ -5,51 +5,26 @@ class FilterableKwListModel(SelectableListModel): """ Adds ERT - plotting keyword specific filtering functionality to the general SelectableListModel """ - def __init__(self, ert, selectable_keys): - SelectableListModel.__init__(self, selectable_keys) - self._ert = ert - self._show_summary_keys = True - self._show_gen_kw_keys = True - self._show_gen_data_keys = True - self._show_custom_kw_keys = True - - def getList(self): - filtered_list = [] - for item in self._items: - if self._show_summary_keys and self.isSummaryKey(item): - filtered_list.append(item) - elif self._show_gen_kw_keys and self.isGenKWKey(item): - filtered_list.append(item) - elif self._show_gen_data_keys and self.isGenDataKey(item): - filtered_list.append(item) - elif self._show_custom_kw_keys and self.isCustomKwKey(item): - filtered_list.append(item) - - return filtered_list - - def keyManager(self): - return self._ert.getKeyManager() - - def isSummaryKey(self, key): - return self.keyManager().isSummaryKey(key) - - def isGenKWKey(self, key): - return self.keyManager().isGenKwKey(key) - - def isGenDataKey(self, key): - return self.keyManager().isGenDataKey(key) + def __init__(self, key_defs): + SelectableListModel.__init__(self, [k["key"] for k in key_defs]) + self._key_defs = key_defs + self._metadata_filters = {} - def isCustomKwKey(self, key): - return self.keyManager().isCustomKwKey(key) - - def setShowSummaryKeys(self, visible): - self._show_summary_keys = visible - - def setShowGenKWKeys(self, visible): - self._show_gen_kw_keys = visible - - def setShowGenDataKeys(self, visible): - self._show_gen_data_keys = visible - - def setShowCustomKwKeys(self, visible): - self._show_custom_kw_keys = visible + def getList(self): + items = [] + for item in self._key_defs: + add = True + for meta_key, meta_value in item["metadata"].items(): + if (meta_key in self._metadata_filters + and not self._metadata_filters[meta_key][meta_value]): + add = False + + if add: + items.append(item["key"]) + return items + + def setFilterOnMetadata(self, key, value, visible): + if key not in self._metadata_filters: + self._metadata_filters[key] = {} + + self._metadata_filters[key][value] = visible \ No newline at end of file diff --git a/ert_gui/tools/plot/plot_api.py b/ert_gui/tools/plot/plot_api.py new file mode 100644 index 00000000000..21ab8b91934 --- /dev/null +++ b/ert_gui/tools/plot/plot_api.py @@ -0,0 +1,135 @@ +from ert_data import loader as loader +import pandas as pd + + +class PlotApi(object): + + def __init__(self, facade): + self._facade = facade + + def all_data_type_keys(self): + """ Returns a list of all the keys except observation keys. For each key a dict is returned with info about + the key""" + return [{"key": key, + "index_type": self._key_index_type(key), + "observations": self._facade.observation_keys(key), + "has_refcase": self._facade.has_refcase(key), + "dimensionality": self._dimensionality_of_key(key), + "metadata": self._metadata(key)} + for key in self._facade.all_data_type_keys()] + + def _metadata(self, key): + meta = {} + if self._facade.is_summary_key(key): + meta["data_origin"] = "Summary" + elif self._facade.is_gen_data_key(key): + meta["data_origin"] = "Gen Data" + elif self._facade.is_gen_kw_key(key): + meta["data_origin"] = "Gen KW" + elif self._facade.is_custom_kw_key(key): + meta["data_origin"] = "Custom Data" + return meta + + def get_all_cases_not_running(self): + """ Returns a list of all cases that are not running. For each case a dict with info about the case is + returned """ + facade = self._facade + return [{"name": case, + "hidden": facade.is_case_hidden(case), + "has_data": facade.case_has_data(case)} + for case + in facade.cases() + if not facade.is_case_running(case)] + + def data_for_key(self, case, key): + """ Returns a pandas DataFrame with the datapoints for a given key for a given case. The row index is + the realization number, and the column index is a multi-index with (key, index/date)""" + + if self._facade.is_summary_key(key): + data = self._facade.gather_summary_data(case, key).T + elif self._facade.is_gen_kw_key(key): + data = self._facade.gather_gen_kw_data(case, key) + elif self._facade.is_custom_kw_key(key): + data = self._facade.gather_custom_kw_data(case, key) + elif self._facade.is_gen_data_key(key): + data = self._facade.gather_gen_data_data(case, key).T + else: + raise ValueError("no such key {}".format(key)) + + data = pd.concat({key: data}, axis=1) + + try: + return data.astype(float) + except ValueError: + return data + + def observations_for_obs_keys(self, case, obs_keys): + """ Returns a pandas DataFrame with the datapoints for a given observation key for a given case. The row index + is the realization number, and the column index is a multi-index with (obs_key, index/date, obs_index), + where index/date is used to relate the observation to the data point it relates to, and obs_index is + the index for the observation itself""" + measured_data = pd.DataFrame() + case_name = case + + for key in obs_keys: + observation_type = self._facade.get_impl_type_name_for_obs_key(key) + data_loader = loader.data_loader_factory(observation_type) + + data = data_loader(self._facade, key, case_name, include_data=False) + + # Simulated data and observations both refer to the data + # index at some levels, so having that information available is + # helpful + self._add_index_range(data) + + data = pd.concat({key: data}, axis=1, names=["obs_key"]) + + measured_data = pd.concat([measured_data, data], axis=1) + + data = measured_data.astype(float) + expected_keys = ["OBS", "STD"] + if not isinstance(data, pd.DataFrame): + raise TypeError( + "Invalid type: {}, should be type: {}".format(type(data), pd.DataFrame) + ) + elif not data.empty and not set(expected_keys).issubset(data.index): + raise ValueError( + "{} should be present in DataFrame index, missing: {}".format( + ["OBS", "STD"], set(expected_keys) - set(data.index) + ) + ) + else: + return data + + def _add_index_range(self, data): + """ + Adds a second column index with which corresponds to the data + index. This is because in libres simulated data and observations + are connected through an observation key and data index, so having + that information available when the data is joined is helpful. + """ + arrays = [data.columns.to_list(), list(range(len(data.columns)))] + tuples = list(zip(*arrays)) + index = pd.MultiIndex.from_tuples(tuples, names=['key_index', 'data_index']) + data.columns = index + + def refcase_data(self, key): + """ Returns a pandas DataFrame with the data points for the refcase for a given data key, if any. + The row index is the index/date and the column index is the key.""" + return self._facade.refcase_data(key) + + def _dimensionality_of_key(self, key): + if self._facade.is_summary_key(key) or self._facade.is_gen_data_key(key): + return 2 + else: + return 1 + + def _key_index_type(self, key): + if self._facade.is_gen_data_key(key): + return "INDEX" + elif self._facade.is_summary_key(key): + return "VALUE" + else: + return None + + diff --git a/ert_gui/tools/plot/plot_case_model.py b/ert_gui/tools/plot/plot_case_model.py index ba80c57af8e..0d4f77030aa 100644 --- a/ert_gui/tools/plot/plot_case_model.py +++ b/ert_gui/tools/plot/plot_case_model.py @@ -6,9 +6,9 @@ class PlotCaseModel(QAbstractItemModel): - def __init__(self): + def __init__(self, cases): QAbstractItemModel.__init__(self) - self.__data = None + self.__data = cases def index(self, row, column, parent=None, *args, **kwargs): return self.createIndex(row, column) @@ -46,9 +46,6 @@ def itemAt(self, index): def getAllItems(self): - if self.__data is None: - self.__data = getAllCasesNotRunning() - return self.__data def __iter__(self): diff --git a/ert_gui/tools/plot/plot_case_selection_widget.py b/ert_gui/tools/plot/plot_case_selection_widget.py index 497625c3714..e259d745fc7 100644 --- a/ert_gui/tools/plot/plot_case_selection_widget.py +++ b/ert_gui/tools/plot/plot_case_selection_widget.py @@ -10,10 +10,10 @@ class CaseSelectionWidget(QWidget): caseSelectionChanged = Signal() - def __init__(self, current_case): + def __init__(self, case_names): QWidget.__init__(self) - - self.__model = PlotCaseModel() + self._cases = case_names + self.__model = PlotCaseModel(case_names) self.__signal_mapper = QSignalMapper(self) self.__case_selectors = {} @@ -38,7 +38,7 @@ def __init__(self, current_case): self.__case_layout.setContentsMargins(0, 0, 0, 0) layout.addLayout(self.__case_layout) - self.addCaseSelector(disabled=True, current_case=current_case) + self.addCaseSelector(disabled=True) layout.addStretch() self.setLayout(layout) @@ -64,7 +64,7 @@ def checkCaseCount(self): self.__add_case_button.setEnabled(state) - def addCaseSelector(self, disabled=False, current_case=None): + def addCaseSelector(self, disabled=False): widget = QWidget() layout = QHBoxLayout() @@ -76,18 +76,8 @@ def addCaseSelector(self, disabled=False, current_case=None): combo.setMinimumContentsLength(20) combo.setModel(self.__model) - if current_case is not None: - index = 0 - for item in self.__model: - if item == current_case: - combo.setCurrentIndex(index) - break - index += 1 - combo.currentIndexChanged.connect(self.caseSelectionChanged.emit) - - layout.addWidget(combo, 1) button = QToolButton() diff --git a/ert_gui/tools/plot/plot_widget.py b/ert_gui/tools/plot/plot_widget.py index bd27dc52072..c70a91ef311 100644 --- a/ert_gui/tools/plot/plot_widget.py +++ b/ert_gui/tools/plot/plot_widget.py @@ -31,17 +31,14 @@ def __init__(self, canvas, parent, coordinates=True): break - class PlotWidget(QWidget): customizationTriggered = Signal() - def __init__(self, name, plotFunction, plot_condition_function_list, plotContextFunction, parent=None): + def __init__(self, name, plotter, parent=None): QWidget.__init__(self, parent) self._name = name - self._plotFunction = plotFunction - self._plotContextFunction = plotContextFunction - self._plot_conditions = plot_condition_function_list + self._plotter = plotter """:type: list of functions """ self._figure = Figure() @@ -62,12 +59,6 @@ def __init__(self, name, plotFunction, plot_condition_function_list, plotContext self._active = False self.resetPlot() - - def getFigure(self): - """ :rtype: matplotlib.figure.Figure""" - return self._figure - - def resetPlot(self): self._figure.clear() @@ -76,37 +67,16 @@ def name(self): """ @rtype: str """ return self._name - def updatePlot(self): - if self.isDirty() and self.isActive(): - # print("Drawing: %s" % self._name) - self.resetPlot() - plot_context = self._plotContextFunction(self.getFigure()) - try: - self._plotFunction(plot_context) - self._canvas.draw() - except Exception as e: - exc_type, exc_value, exc_tb = sys.exc_info() - sys.stderr.write("%s\n" % ("-" * 80)) - traceback.print_tb(exc_tb) - sys.stderr.write("Exception type: %s\n" % exc_type.__name__) - sys.stderr.write("%s\n" % e) - sys.stderr.write("%s\n" % ("-" * 80)) - sys.stderr.write("An error occurred during plotting. This stack trace is helpful for diagnosing the problem.") - - self.setDirty(False) - - - def setDirty(self, dirty=True): - self._dirty = dirty - - def isDirty(self): - return self._dirty - - def setActive(self, active=True): - self._active = active - - def isActive(self): - return self._active - - def canPlotKey(self, key): - return any([plotConditionFunction(key) for plotConditionFunction in self._plot_conditions]) + def updatePlot(self, plot_context, case_to_data_map, observations): + self.resetPlot() + try: + self._plotter.plot(self._figure, plot_context, case_to_data_map, observations) + self._canvas.draw() + except Exception as e: + exc_type, exc_value, exc_tb = sys.exc_info() + sys.stderr.write("%s\n" % ("-" * 80)) + traceback.print_tb(exc_tb) + sys.stderr.write("Exception type: %s\n" % exc_type.__name__) + sys.stderr.write("%s\n" % e) + sys.stderr.write("%s\n" % ("-" * 80)) + sys.stderr.write("An error occurred during plotting. This stack trace is helpful for diagnosing the problem.") diff --git a/ert_gui/tools/plot/plot_window.py b/ert_gui/tools/plot/plot_window.py index 3e3161096cc..e2f18394d42 100644 --- a/ert_gui/tools/plot/plot_window.py +++ b/ert_gui/tools/plot/plot_window.py @@ -1,15 +1,21 @@ from qtpy.QtCore import Qt from qtpy.QtWidgets import QMainWindow, QDockWidget, QTabWidget, QWidget, QVBoxLayout - +from ert_gui.plottery.plots.ccsp import CrossCaseStatisticsPlot +from ert_gui.plottery.plots.distribution import DistributionPlot +from ert_gui.plottery.plots.ensemble import EnsemblePlot +from ert_gui.plottery.plots.gaussian_kde import GaussianKDEPlot +from ert_gui.plottery.plots.histogram import HistogramPlot +from ert_gui.plottery.plots.statistics import StatisticsPlot from ert_shared import ERT from ert_gui.ertwidgets import showWaitCursorWhileWaiting -from ert_gui.ertwidgets.models.ertmodel import getCurrentCaseName -from ert_gui.plottery import PlotContext, PlotDataGatherer as PDG, PlotConfig, plots, PlotConfigFactory +from ert_gui.plottery import PlotContext, PlotConfig -from ert_gui.tools.plot import DataTypeKeysWidget, CaseSelectionWidget, PlotWidget, DataTypeKeysListModel +from ert_gui.tools.plot import DataTypeKeysWidget, CaseSelectionWidget, PlotWidget from ert_gui.tools.plot.customize import PlotCustomizer +from ert_gui.tools.plot.plot_api import PlotApi + CROSS_CASE_STATISTICS = "Cross Case Statistics" DISTRIBUTION = "Distribution" GAUSSIAN_KDE = "Gaussian KDE" @@ -23,28 +29,19 @@ class PlotWindow(QMainWindow): def __init__(self, config_file, parent): QMainWindow.__init__(self, parent) - self._ert = ERT.ert - """:type: res.enkf.enkf_main.EnKFMain""" - - key_manager = self._ert.getKeyManager() - """:type: res.enkf.key_manager.KeyManager """ + self._api = PlotApi(ERT.enkf_facade) self.setMinimumWidth(850) self.setMinimumHeight(650) self.setWindowTitle("Plotting - {}".format(config_file)) self.activateWindow() + self._key_definitions = self._api.all_data_type_keys() + self._plot_customizer = PlotCustomizer(self, self._key_definitions) - self._plot_customizer = PlotCustomizer(self) - - def plotConfigCreator(key): - return PlotConfigFactory.createPlotConfigForKey(self._ert, key) - - self._plot_customizer.setPlotConfigCreator(plotConfigCreator) self._plot_customizer.settingsChanged.connect(self.keySelected) self._central_tab = QTabWidget() - self._central_tab.currentChanged.connect(self.currentPlotChanged) central_widget = QWidget() central_layout = QVBoxLayout() @@ -58,69 +55,62 @@ def plotConfigCreator(key): self._plot_widgets = [] """:type: list of PlotWidget""" - self._data_gatherers = [] - """:type: list of PlotDataGatherer """ - - summary_gatherer = self.createDataGatherer(PDG.gatherSummaryData, key_manager.isSummaryKey, refcaseGatherFunc=PDG.gatherSummaryRefcaseData, observationGatherFunc=PDG.gatherSummaryObservationData, historyGatherFunc=PDG.gatherSummaryHistoryData) - gen_data_gatherer = self.createDataGatherer(PDG.gatherGenDataData, key_manager.isGenDataKey, observationGatherFunc=PDG.gatherGenDataObservationData) - gen_kw_gatherer = self.createDataGatherer(PDG.gatherGenKwData, key_manager.isGenKwKey) - custom_kw_gatherer = self.createDataGatherer(PDG.gatherCustomKwData, key_manager.isCustomKwKey) + self.addPlotWidget(ENSEMBLE, EnsemblePlot()) + self.addPlotWidget(STATISTICS, StatisticsPlot()) + self.addPlotWidget(HISTOGRAM, HistogramPlot()) + self.addPlotWidget(GAUSSIAN_KDE, GaussianKDEPlot()) + self.addPlotWidget(DISTRIBUTION, DistributionPlot()) + self.addPlotWidget(CROSS_CASE_STATISTICS, CrossCaseStatisticsPlot()) + self._central_tab.currentChanged.connect(self.currentPlotChanged) - self.addPlotWidget(ENSEMBLE, plots.plotEnsemble, [summary_gatherer, gen_data_gatherer]) - self.addPlotWidget(STATISTICS, plots.plotStatistics, [summary_gatherer, gen_data_gatherer]) - self.addPlotWidget(HISTOGRAM, plots.plotHistogram, [gen_kw_gatherer, custom_kw_gatherer]) - self.addPlotWidget(GAUSSIAN_KDE, plots.plotGaussianKDE, [gen_kw_gatherer, custom_kw_gatherer]) - self.addPlotWidget(DISTRIBUTION, plots.plotDistribution, [gen_kw_gatherer, custom_kw_gatherer]) - self.addPlotWidget(CROSS_CASE_STATISTICS, plots.plotCrossCaseStatistics, [gen_kw_gatherer, custom_kw_gatherer]) + cases = self._api.get_all_cases_not_running() + case_names = [case["name"] for case in cases if not case["hidden"]] - data_types_key_model = DataTypeKeysListModel(self._ert) - self._data_type_keys_widget = DataTypeKeysWidget(data_types_key_model) + self._data_type_keys_widget = DataTypeKeysWidget(self._key_definitions) self._data_type_keys_widget.dataTypeKeySelected.connect(self.keySelected) self.addDock("Data types", self._data_type_keys_widget) - - current_case = getCurrentCaseName() - self._case_selection_widget = CaseSelectionWidget(current_case) + self._case_selection_widget = CaseSelectionWidget(case_names) self._case_selection_widget.caseSelectionChanged.connect(self.keySelected) self.addDock("Plot case", self._case_selection_widget) current_plot_widget = self._plot_widgets[self._central_tab.currentIndex()] - current_plot_widget.setActive() self._data_type_keys_widget.selectDefault() self._updateCustomizer(current_plot_widget) - - - def createDataGatherer(self, dataGatherFunc, gatherConditionFunc, refcaseGatherFunc=None, observationGatherFunc=None, historyGatherFunc=None): - data_gatherer = PDG(dataGatherFunc, gatherConditionFunc, refcaseGatherFunc=refcaseGatherFunc, observationGatherFunc=observationGatherFunc, historyGatherFunc=historyGatherFunc) - self._data_gatherers.append(data_gatherer) - return data_gatherer - - def currentPlotChanged(self): + key_def = self.getSelectedKey() + key = key_def["key"] + for plot_widget in self._plot_widgets: - plot_widget.setActive(False) index = self._central_tab.indexOf(plot_widget) - if index == self._central_tab.currentIndex() and plot_widget.canPlotKey(self.getSelectedKey()): - plot_widget.setActive() + if index == self._central_tab.currentIndex() \ + and plot_widget._plotter.dimensionality == key_def["dimensionality"]: self._updateCustomizer(plot_widget) - plot_widget.updatePlot() + cases = self._case_selection_widget.getPlotCaseNames() + case_to_data_map = {case: self._api.data_for_key(case, key)[key] for case in cases} + if len(key_def["observations"]) > 0: + observations = self._api.observations_for_obs_keys(cases[0], key_def["observations"]) + else: + observations = None - def _updateCustomizer(self, plot_widget): - """ @type plot_widget: PlotWidget """ - key = self.getSelectedKey() - key_manager = self._ert.getKeyManager() + plot_config = PlotConfig.createCopy(self._plot_customizer.getPlotConfig()) + plot_config.setTitle(key) + plot_context = PlotContext(plot_config, cases, key) + + if key_def["has_refcase"]: + plot_context.refcase_data = self._api.refcase_data(key) - index_type = PlotContext.UNKNOWN_AXIS + plot_widget.updatePlot(plot_context, case_to_data_map, observations) - if key_manager.isGenDataKey(key): - index_type = PlotContext.INDEX_AXIS - elif key_manager.isSummaryKey(key): - index_type = PlotContext.DATE_AXIS + def _updateCustomizer(self, plot_widget): + """ @type plot_widget: PlotWidget """ + key_def = self.getSelectedKey() + index_type = key_def["index_type"] x_axis_type = PlotContext.UNKNOWN_AXIS y_axis_type = PlotContext.UNKNOWN_AXIS @@ -144,25 +134,11 @@ def _updateCustomizer(self, plot_widget): self._plot_customizer.setAxisTypes(x_axis_type, y_axis_type) - - def createPlotContext(self, figure): - key = self.getSelectedKey() - cases = self._case_selection_widget.getPlotCaseNames() - data_gatherer = self.getDataGathererForKey(key) - plot_config = PlotConfig.createCopy(self._plot_customizer.getPlotConfig()) - plot_config.setTitle(key) - return PlotContext(self._ert, figure, plot_config, cases, key, data_gatherer) - - def getDataGathererForKey(self, key): - """ @rtype: PlotDataGatherer """ - return next((data_gatherer for data_gatherer in self._data_gatherers if data_gatherer.canGatherDataForKey(key)), None) - def getSelectedKey(self): - return str(self._data_type_keys_widget.getSelectedItem()) + return self._data_type_keys_widget.getSelectedItem() - def addPlotWidget(self, name, plotFunction, data_gatherers, enabled=True): - plot_condition_function_list = [data_gatherer.canGatherDataForKey for data_gatherer in data_gatherers] - plot_widget = PlotWidget(name, plotFunction, plot_condition_function_list, self.createPlotContext) + def addPlotWidget(self, name, plotter, enabled=True): + plot_widget = PlotWidget(name, plotter) plot_widget.customizationTriggered.connect(self.toggleCustomizeDialog) index = self._central_tab.addTab(plot_widget, name) @@ -183,17 +159,15 @@ def addDock(self, name, widget, area=Qt.LeftDockWidgetArea, allowed_areas=Qt.All @showWaitCursorWhileWaiting def keySelected(self): - key = self.getSelectedKey() - self._plot_customizer.switchPlotConfigHistory(key) + key_def = self.getSelectedKey() + self._plot_customizer.switchPlotConfigHistory(key_def) for plot_widget in self._plot_widgets: - plot_widget.setDirty() index = self._central_tab.indexOf(plot_widget) - self._central_tab.setTabEnabled(index, plot_widget.canPlotKey(key)) + self._central_tab.setTabEnabled(index, plot_widget._plotter.dimensionality == key_def["dimensionality"]) + + self.currentPlotChanged() - for plot_widget in self._plot_widgets: - if plot_widget.canPlotKey(key): - plot_widget.updatePlot() def toggleCustomizeDialog(self): diff --git a/ert_gui/tools/plot/widgets/copy_style_to_dialog.py b/ert_gui/tools/plot/widgets/copy_style_to_dialog.py index 6cd1dbd0663..3976728fb00 100644 --- a/ert_gui/tools/plot/widgets/copy_style_to_dialog.py +++ b/ert_gui/tools/plot/widgets/copy_style_to_dialog.py @@ -7,7 +7,7 @@ class CopyStyleToDialog(QDialog): - def __init__(self, parent=None, current_key='', selectable_keys=[]): + def __init__(self, parent, current_key, key_defs): QWidget.__init__(self, parent) self.setMinimumWidth(450) self.setMinimumHeight(200) @@ -17,19 +17,14 @@ def __init__(self, parent=None, current_key='', selectable_keys=[]): layout = QFormLayout(self) - self._ert = ERT.ert - """:type: res.enkf.enkf_main.EnKFMain""" - - self.model = self._ert - - self._filter_popup = FilterPopup(self) + self._filter_popup = FilterPopup(self, key_defs) self._filter_popup.filterSettingsChanged.connect(self.filterSettingsChanged) filter_popup_button = QToolButton() filter_popup_button.setIcon(resourceIcon("ide/cog_edit.png")) filter_popup_button.clicked.connect(self._filter_popup.show) - self._list_model = FilterableKwListModel(self._ert, selectable_keys) + self._list_model = FilterableKwListModel(key_defs) self._list_model.unselectAll() self._cl = CheckList(self._list_model, custom_filter_button=filter_popup_button) @@ -56,8 +51,6 @@ def getSelectedKeys(self): return self._list_model.getSelectedItems() def filterSettingsChanged(self, item): - self._list_model.setShowSummaryKeys(item["summary"]) - self._list_model.setShowGenKWKeys(item["gen_kw"]) - self._list_model.setShowGenDataKeys(item["gen_data"]) - self._list_model.setShowCustomKwKeys(item["custom_kw"]) + for value, visible in item.items(): + self._list_model.setFilterOnMetadata("data_origin", value, visible) self._cl.modelChanged() diff --git a/ert_shared/libres_facade.py b/ert_shared/libres_facade.py index 92dac73ae5b..a338b24d2ee 100644 --- a/ert_shared/libres_facade.py +++ b/ert_shared/libres_facade.py @@ -1,8 +1,10 @@ +from pandas import DataFrame from res.analysis.analysis_module import AnalysisModule from res.analysis.enums.analysis_module_options_enum import \ AnalysisModuleOptionsEnum from res.enkf.export import (GenDataCollector, SummaryCollector, - SummaryObservationCollector) + SummaryObservationCollector, GenDataObservationCollector, GenKwCollector, + CustomKWCollector) from res.enkf.plot_data import PlotBlockDataLoader @@ -81,3 +83,120 @@ def select_or_create_new_case(self, case_name): if self.get_current_case_name() != case_name: fs = self._enkf_main.getEnkfFsManager().getFileSystem(case_name) self._enkf_main.getEnkfFsManager().switchFileSystem(fs) + + def cases(self): + return self._enkf_main.getEnkfFsManager().getCaseList() + + def is_case_hidden(self, case): + return self._enkf_main.getEnkfFsManager().isCaseHidden(case) + + def case_has_data(self, case): + return self._enkf_main.getEnkfFsManager().caseHasData(case) + + def is_case_running(self, case): + return self._enkf_main.getEnkfFsManager().isCaseRunning(case) + + def all_data_type_keys(self): + return self._enkf_main.getKeyManager().allDataTypeKeys() + + def observation_keys(self, key): + if self._enkf_main.getKeyManager().isGenDataKey(key): + key_parts = key.split("@") + key = key_parts[0] + if len(key_parts) > 1: + report_step = int(key_parts[1]) + else: + report_step = 0 + + obs_key = GenDataObservationCollector.getObservationKeyForDataKey(self._enkf_main, key, report_step) + if obs_key is not None: + return [obs_key] + else: + return [] + elif self._enkf_main.getKeyManager().isSummaryKey(key): + return [str(k) for k in self._enkf_main.ensembleConfig().getNode(key).getObservationKeys()] + else: + return [] + + def gather_gen_kw_data(self, case, key): + """ :rtype: pandas.DataFrame """ + data = GenKwCollector.loadAllGenKwData(self._enkf_main, case, [key]) + if key in data: + return data[key].dropna() + else: + return DataFrame() + + def gather_summary_data(self, case, key): + """ :rtype: pandas.DataFrame """ + data = SummaryCollector.loadAllSummaryData(self._enkf_main, case, [key]) + if not data.empty: + data = data.reset_index() + + if any(data.duplicated()): + print("** Warning: The simulation data contains duplicate " + "timestamps. A possible explanation is that your " + "simulation timestep is less than a second.") + data = data.drop_duplicates() + + data = data.pivot(index="Date", columns="Realization", values=key) + + return data + + def has_refcase(self, key): + refcase = self._enkf_main.eclConfig().getRefcase() + return refcase is not None and key in refcase + + def refcase_data(self, key): + refcase = self._enkf_main.eclConfig().getRefcase() + + if refcase is None or key not in refcase: + return DataFrame() + + values = refcase.numpy_vector(key, report_only=False) + dates = refcase.numpy_dates + + data = DataFrame(zip(dates, values), columns=['Date', key]) + data.set_index("Date", inplace=True) + + return data.iloc[1:] + + def gather_gen_data_data(self, case, key): + """ :rtype: pandas.DataFrame """ + key_parts = key.split("@") + key = key_parts[0] + if len(key_parts) > 1: + report_step = int(key_parts[1]) + else: + report_step = 0 + + try: + data = GenDataCollector.loadGenData(self._enkf_main, case, key, report_step) + except (ValueError, KeyError): + data = DataFrame() + + return data.dropna() # removes all rows that has a NaN + + def gather_custom_kw_data(self, case, key): + """ :rtype: pandas.DataFrame """ + data = CustomKWCollector.loadAllCustomKWData(self._enkf_main, case, [key]) + + if key in data: + return data[key] + else: + return data + + def is_summary_key(self, key): + """ :rtype: bool """ + return key in self._enkf_main.getKeyManager().summaryKeys() + + def is_gen_kw_key(self, key): + """ :rtype: bool """ + return key in self._enkf_main.getKeyManager().genKwKeys() + + def is_custom_kw_key(self, key): + """ :rtype: bool """ + return key in self._enkf_main.getKeyManager().customKwKeys() + + def is_gen_data_key(self, key): + """ :rtype: bool """ + return key in self._enkf_main.getKeyManager().genDataKeys() diff --git a/tests/gui/plottery/test_plot_data_gatherer.py b/tests/gui/plottery/test_plot_data_gatherer.py deleted file mode 100644 index 89b70ef5f9d..00000000000 --- a/tests/gui/plottery/test_plot_data_gatherer.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -import datetime - -import pandas as pd - -from tests import ErtTest -from res.test import ErtTestContext -from ert_gui.plottery.plot_data_gatherer import PlotDataGatherer - - -class PlotGatherTest(ErtTest): - - def test_gatherSummaryRefcaseData(self): - config_file = self.createTestPath(os.path.join("local", "snake_oil", "snake_oil.ert")) - with ErtTestContext('SummaryRefcaseData', config_file) as work_area: - ert = work_area.getErt() - key = "WOPRH:OP1" - result_data = PlotDataGatherer.gatherSummaryRefcaseData(ert, key) - - expected_data = [ - (0, datetime.date(2010,1,2), 1.03836009657e-05), - (244, datetime.date(2010, 9, 3), 0.46973800659), - (1267, datetime.date(2013, 6, 22), 0.11672365665), - (-1, datetime.date(2015, 6, 23), 0.00820410997), - ] - - self.assertEqual(len(result_data), 1999) - - for index, date, value in expected_data: - self.assertAlmostEqual(value, result_data.iloc[index][key], delta=1E-10) - self.assertEqual(date, result_data.iloc[index].name.date()) - - key = "not_a_key" - - result_data = PlotDataGatherer.gatherSummaryRefcaseData(ert, key) - expected_data = pd.DataFrame() - - pd.testing.assert_frame_equal(result_data, expected_data, check_exact=True) diff --git a/tests/gui/tools/plot/test_plot_api.py b/tests/gui/tools/plot/test_plot_api.py new file mode 100644 index 00000000000..04dfe90fdce --- /dev/null +++ b/tests/gui/tools/plot/test_plot_api.py @@ -0,0 +1,125 @@ +import os +import shutil + +from pandas import DataFrame + +from ert_gui.tools.plot.plot_api import PlotApi +from res.enkf import EnKFMain, ResConfig + +from ert_shared.libres_facade import LibresFacade +from tests.utils import SOURCE_DIR, tmpdir +from unittest import TestCase + + +class PlotApiTest(TestCase): + + def api(self): + config_file = 'snake_oil.ert' + + rc = ResConfig(user_config_file=config_file) + rc.convertToCReference(None) + ert = EnKFMain(rc) + facade = LibresFacade(ert) + api = PlotApi(facade) + return api + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_all_keys_present(self): + api = self.api() + + key_defs = api.all_data_type_keys() + keys = {x["key"] for x in key_defs} + expected = {'BPR:1,3,8', 'BPR:445', 'BPR:5,5,5', 'BPR:721', 'FGIP', 'FGIPH', 'FGOR', 'FGORH', 'FGPR', 'FGPRH', + 'FGPT', 'FGPTH', 'FOIP', 'FOIPH', 'FOPR', 'FOPRH', 'FOPT', 'FOPTH', 'FWCT', 'FWCTH', 'FWIP', + 'FWIPH', 'FWPR', 'FWPRH', 'FWPT', 'FWPTH', 'TIME', 'WGOR:OP1', 'WGOR:OP2', 'WGORH:OP1', 'WGORH:OP2', + 'WGPR:OP1', 'WGPR:OP2', 'WGPRH:OP1', 'WGPRH:OP2', 'WOPR:OP1', 'WOPR:OP2', 'WOPRH:OP1', 'WOPRH:OP2', + 'WWCT:OP1', 'WWCT:OP2', 'WWCTH:OP1', 'WWCTH:OP2', 'WWPR:OP1', 'WWPR:OP2', 'WWPRH:OP1', 'WWPRH:OP2', + 'SNAKE_OIL_PARAM:BPR_138_PERSISTENCE', 'SNAKE_OIL_PARAM:BPR_555_PERSISTENCE', + 'SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE', 'SNAKE_OIL_PARAM:OP1_OCTAVES', 'SNAKE_OIL_PARAM:OP1_OFFSET', + 'SNAKE_OIL_PARAM:OP1_PERSISTENCE', 'SNAKE_OIL_PARAM:OP2_DIVERGENCE_SCALE', + 'SNAKE_OIL_PARAM:OP2_OCTAVES', 'SNAKE_OIL_PARAM:OP2_OFFSET', 'SNAKE_OIL_PARAM:OP2_PERSISTENCE', + 'SNAKE_OIL_NPV:NPV', 'SNAKE_OIL_NPV:RATING', 'SNAKE_OIL_GPR_DIFF@199', 'SNAKE_OIL_OPR_DIFF@199', + 'SNAKE_OIL_WPR_DIFF@199'} + self.assertSetEqual(expected, keys) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_observation_key_present(self): + api = self.api() + key_defs = api.all_data_type_keys() + expected_obs = { + 'FOPR': ['FOPR'], + 'WOPR:OP1': ['WOPR_OP1_108', 'WOPR_OP1_190', 'WOPR_OP1_144', 'WOPR_OP1_9', 'WOPR_OP1_72', 'WOPR_OP1_36'], + 'SNAKE_OIL_WPR_DIFF@199': ["WPR_DIFF_1"] + } + + for key_def in key_defs: + if key_def["key"] in expected_obs: + expected = expected_obs[key_def["key"]] + self.assertEqual(expected, key_def["observations"]) + else: + self.assertEqual(0, len(key_def["observations"])) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_can_load_data_and_observations(self): + api = self.api() + key_defs = api.all_data_type_keys() + cases = api.get_all_cases_not_running() + for case in cases: + for key_def in key_defs: + obs = key_def["observations"] + obs_data = api.observations_for_obs_keys(case["name"], obs) + data = api.data_for_key(case["name"], key_def["key"]) + + self.assertIsInstance(data, DataFrame) + self.assertTrue(not data.empty) + + self.assertIsInstance(obs_data, DataFrame) + if len(obs) > 0: + self.assertTrue(not obs_data.empty) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_no_storage(self): + shutil.rmtree("storage") + api = self.api() + key_defs = api.all_data_type_keys() + cases = api.get_all_cases_not_running() + for case in cases: + for key_def in key_defs: + obs = key_def["observations"] + obs_data = api.observations_for_obs_keys(case["name"], obs) + data = api.data_for_key(case["name"], key_def["key"]) + self.assertIsInstance(obs_data, DataFrame) + self.assertIsInstance(data, DataFrame) + self.assertTrue(data.empty) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_key_def_structure(self): + shutil.rmtree("storage") + api = self.api() + key_defs = api.all_data_type_keys() + fopr = next(x for x in key_defs if x["key"] == "FOPR") + + expected = { + 'dimensionality': 2, + 'has_refcase': True, + 'index_type': 'VALUE', + 'key': 'FOPR', + 'metadata': {'data_origin': 'Summary'}, + 'observations': ['FOPR'] + } + + self.assertEqual(expected, fopr) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_case_structure(self): + api = self.api() + cases = api.get_all_cases_not_running() + case = next(x for x in cases if x["name"] == "default_0") + + expected = { + 'has_data': True, + 'hidden': False, + 'name': 'default_0' + } + + self.assertEqual(expected, case) diff --git a/tests/shared/test_libres_facade.py b/tests/shared/test_libres_facade.py new file mode 100644 index 00000000000..8add1a87da8 --- /dev/null +++ b/tests/shared/test_libres_facade.py @@ -0,0 +1,166 @@ +import os +from pandas.core.base import PandasObject + +from res.enkf import EnKFMain, ResConfig + +from ert_shared.libres_facade import LibresFacade +from tests.utils import SOURCE_DIR, tmpdir +from unittest import TestCase + + +class LibresFacadeTest(TestCase): + + def facade(self): + config_file = 'snake_oil.ert' + + rc = ResConfig(user_config_file=config_file) + rc.convertToCReference(None) + ert = EnKFMain(rc) + facade = LibresFacade(ert) + return facade + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_keyword_type_checks(self): + facade = self.facade() + self.assertTrue(facade.is_custom_kw_key('SNAKE_OIL_NPV:NPV')) + self.assertTrue(facade.is_gen_data_key('SNAKE_OIL_GPR_DIFF@199')) + self.assertTrue(facade.is_summary_key('BPR:1,3,8')) + self.assertTrue(facade.is_gen_kw_key('SNAKE_OIL_PARAM:BPR_138_PERSISTENCE')) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_keyword_type_checks_missing_key(self): + facade = self.facade() + self.assertFalse(facade.is_custom_kw_key('nokey')) + self.assertFalse(facade.is_gen_data_key('nokey')) + self.assertFalse(facade.is_summary_key('nokey')) + self.assertFalse(facade.is_gen_kw_key('nokey')) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_data_fetching(self): + facade = self.facade() + data = [ + facade.gather_custom_kw_data('default_0', 'SNAKE_OIL_NPV:NPV'), + facade.gather_gen_data_data('default_0', 'SNAKE_OIL_GPR_DIFF@199'), + facade.gather_summary_data('default_0', 'BPR:1,3,8'), + facade.gather_gen_kw_data('default_0', 'SNAKE_OIL_PARAM:BPR_138_PERSISTENCE') + ] + + for dataframe in data: + self.assertIsInstance(dataframe, PandasObject) + self.assertFalse(dataframe.empty) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_data_fetching_missing_case(self): + facade = self.facade() + data = [ + facade.gather_custom_kw_data('nocase', 'SNAKE_OIL_NPV:NPV'), + facade.gather_gen_data_data('nocase', 'SNAKE_OIL_GPR_DIFF@199'), + facade.gather_summary_data('nocase', 'BPR:1,3,8'), + facade.gather_gen_kw_data('nocase', 'SNAKE_OIL_PARAM:BPR_138_PERSISTENCE') + ] + + for dataframe in data: + self.assertIsInstance(dataframe, PandasObject) + self.assertTrue(dataframe.empty) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_data_fetching_missing_key(self): + facade = self.facade() + data = [ + facade.gather_custom_kw_data('default_0', 'nokey'), + facade.gather_gen_data_data('default_0', 'nokey'), + facade.gather_summary_data('default_0', 'nokey'), + facade.gather_gen_kw_data('default_0', 'nokey') + ] + + for dataframe in data: + self.assertIsInstance(dataframe, PandasObject) + self.assertTrue(dataframe.empty) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_cases_list(self): + facade = self.facade() + cases = facade.cases() + self.assertEqual(["default_0", "default_1"], cases) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_case_is_hidden(self): + facade = self.facade() + self.assertFalse(facade.is_case_hidden("default_0")) + self.assertFalse(facade.is_case_hidden("nocase")) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_case_has_data(self): + facade = self.facade() + self.assertTrue(facade.case_has_data("default_0")) + self.assertFalse(facade.case_has_data("default")) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_case_is_running(self): + facade = self.facade() + self.assertFalse(facade.is_case_running("default_0")) + self.assertFalse(facade.is_case_running("nocase")) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_all_data_type_keys(self): + facade = self.facade() + keys = facade.all_data_type_keys() + + expected = ['BPR:1,3,8', 'BPR:445', 'BPR:5,5,5', 'BPR:721', 'FGIP', 'FGIPH', 'FGOR', 'FGORH', 'FGPR', 'FGPRH', + 'FGPT', 'FGPTH', 'FOIP', 'FOIPH', 'FOPR', 'FOPRH', 'FOPT', 'FOPTH', 'FWCT', 'FWCTH', 'FWIP', + 'FWIPH', 'FWPR', 'FWPRH', 'FWPT', 'FWPTH', 'TIME', 'WGOR:OP1', 'WGOR:OP2', 'WGORH:OP1', 'WGORH:OP2', + 'WGPR:OP1', 'WGPR:OP2', 'WGPRH:OP1', 'WGPRH:OP2', 'WOPR:OP1', 'WOPR:OP2', 'WOPRH:OP1', 'WOPRH:OP2', + 'WWCT:OP1', 'WWCT:OP2', 'WWCTH:OP1', 'WWCTH:OP2', 'WWPR:OP1', 'WWPR:OP2', 'WWPRH:OP1', 'WWPRH:OP2', + 'SNAKE_OIL_PARAM:BPR_138_PERSISTENCE', 'SNAKE_OIL_PARAM:BPR_555_PERSISTENCE', + 'SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE', 'SNAKE_OIL_PARAM:OP1_OCTAVES', 'SNAKE_OIL_PARAM:OP1_OFFSET', + 'SNAKE_OIL_PARAM:OP1_PERSISTENCE', 'SNAKE_OIL_PARAM:OP2_DIVERGENCE_SCALE', + 'SNAKE_OIL_PARAM:OP2_OCTAVES', 'SNAKE_OIL_PARAM:OP2_OFFSET', 'SNAKE_OIL_PARAM:OP2_PERSISTENCE', + 'SNAKE_OIL_NPV:NPV', 'SNAKE_OIL_NPV:RATING', 'SNAKE_OIL_GPR_DIFF@199', 'SNAKE_OIL_OPR_DIFF@199', + 'SNAKE_OIL_WPR_DIFF@199'] + + self.assertEqual(expected, keys) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_observation_keys(self): + facade = self.facade() + + expected_obs = { + 'FOPR': ['FOPR'], + 'WOPR:OP1': ['WOPR_OP1_108', 'WOPR_OP1_190', 'WOPR_OP1_144', 'WOPR_OP1_9', 'WOPR_OP1_72', 'WOPR_OP1_36'], + 'SNAKE_OIL_WPR_DIFF@199': ["WPR_DIFF_1"] + } + + for key in facade.all_data_type_keys(): + obs_keys = facade.observation_keys(key) + expected = [] + if key in expected_obs: + expected = expected_obs[key] + self.assertEqual(expected, obs_keys) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_observation_keys_missing_key(self): + facade = self.facade() + obs_keys = facade.observation_keys("nokey") + self.assertEqual([], obs_keys) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_case_has_refcase(self): + facade = self.facade() + self.assertTrue(facade.has_refcase('FOPR')) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_case_has_refcase_missing_key(self): + facade = self.facade() + self.assertFalse(facade.has_refcase('nokey')) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_case_refcase_data(self): + facade = self.facade() + data = facade.refcase_data('FOPR') + self.assertIsInstance(data, PandasObject) + + @tmpdir(os.path.join(SOURCE_DIR, 'test-data/local/snake_oil')) + def test_case_refcase_data_missing_key(self): + facade = self.facade() + data = facade.refcase_data('nokey') + self.assertIsInstance(data, PandasObject)