diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 63b30987a..469340344 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -325,7 +325,7 @@ def plot( /, *, show: bool = True, - explanation_idx: int | None = None, + explanation_index: int | None = None, **kwargs: Any, ) -> plt.Axes: """Plot the Shapley values using the provided plot type. @@ -334,7 +334,7 @@ def plot( plot_type: The type of plot to be created. data: See :meth:`explain`. show: Boolean flag determining if the plot is to be rendered. - explanation_idx: Positional index of the data point that should be + explanation_index: Positional index of the data point that should be explained. Only relevant for plot types that can only handle a single data point. **kwargs: Additional keyword arguments passed to the plot function. @@ -360,15 +360,15 @@ def plot( plot_func = getattr(shap.plots, plot_type) # Handle plot types that only explain a single data point - if plot_type in {"force"}: - if explanation_idx is None: + if plot_type == "force": + if explanation_index is None: warnings.warn( - f"When using plot type '{plot_type}', a 'explanation_idx' must be " - f"chosen to identify a single data point that should be explained. " - f"Choosing the first measurement at position 0." + f"When using plot type '{plot_type}', an 'explanation_index' must " + f"be chosen to identify a single data point that should be " + f"explained. Choosing the first entry at position 0." ) - explanation_idx = 0 - toplot = self.explain(data.iloc[[explanation_idx]]) + explanation_index = 0 + toplot = self.explain(data.iloc[[explanation_index]]) kwargs["matplotlib"] = True else: toplot = self.explain(data)