From 0c8e94566232f096ddb35b27424e31c1cbd19ac3 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Sun, 29 Sep 2024 20:26:14 +0200 Subject: [PATCH 01/92] Optional import of shap package. --- README.md | 1 + baybe/_optional/info.py | 1 + baybe/_optional/shap.py | 17 +++++++++++++++++ mypy.ini | 3 +++ pyproject.toml | 4 ++++ tox.ini | 2 +- 6 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 baybe/_optional/shap.py diff --git a/README.md b/README.md index 91383dfd1..562147ac5 100644 --- a/README.md +++ b/README.md @@ -296,6 +296,7 @@ The available groups are: - `mypy`: Required for static type checking. - `onnx`: Required for using custom surrogate models in [ONNX format](https://onnx.ai). - `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/) +- `shap`: Required for feature importance ranking via [SHAP](https://shap.readthedocs.io/) - `simulation`: Enabling the [simulation](https://emdgroup.github.io/baybe/stable/_autosummary/baybe.simulation.html) module. - `test`: Required for running the tests. - `dev`: All of the above plus `tox` and `pip-audit`. For code contributors. 
diff --git a/baybe/_optional/info.py b/baybe/_optional/info.py index e725b4799..45efd69a0 100644 --- a/baybe/_optional/info.py +++ b/baybe/_optional/info.py @@ -28,6 +28,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404 MORDRED_INSTALLED = find_spec("mordred") is not None ONNX_INSTALLED = find_spec("onnxruntime") is not None POLARS_INSTALLED = find_spec("polars") is not None + SHAP_INSTALLED = find_spec("shap") is not None PRE_COMMIT_INSTALLED = find_spec("pre_commit") is not None PYDOCLINT_INSTALLED = find_spec("pydoclint") is not None RDKIT_INSTALLED = find_spec("rdkit") is not None diff --git a/baybe/_optional/shap.py b/baybe/_optional/shap.py new file mode 100644 index 000000000..0d7916df9 --- /dev/null +++ b/baybe/_optional/shap.py @@ -0,0 +1,17 @@ +"""Optional SHAP import.""" + +from baybe.exceptions import OptionalImportError + +try: + import shap +except ModuleNotFoundError as ex: + raise OptionalImportError( + "Feature importance ranking functionality is unavailable " + "because 'shap' is not installed. " + "Consider installing BayBE with 'polars' dependency, e.g. via " + "`pip install baybe[shap]`." 
+ ) from ex + +__all__ = [ + "shap", +] diff --git a/mypy.ini b/mypy.ini index 7a9737aea..c285a4e8c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -62,3 +62,6 @@ ignore_missing_imports = True [mypy-polars] ignore_missing_imports = True + +[mypy-shap] +ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index 366f5d451..ced3e43c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,6 +138,10 @@ polars = [ "polars[pyarrow]>=0.19.19,<2", ] +shap = [ + "shap>=0.46.0", +] + simulation = [ "xyzpy>=1.2.1", ] diff --git a/tox.ini b/tox.ini index fe2ced727..79ab98ec1 100644 --- a/tox.ini +++ b/tox.ini @@ -5,7 +5,7 @@ isolated_build = True [testenv:fulltest,fulltest-py{310,311,312}] description = Run PyTest with all extra functionality -extras = chem,examples,lint,onnx,polars,simulation,test +extras = chem,examples,lint,onnx,polars,shap,simulation,test passenv = CI BAYBE_NUMPY_USE_SINGLE_PRECISION From cbe2e825171ced8c130dd4bac5a26bb37f7786a6 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Sun, 29 Sep 2024 21:56:17 +0200 Subject: [PATCH 02/92] 1st implementation of SHAP utilities in experimental space and with procedural approach. --- baybe/_optional/shap.py | 5 ++-- baybe/utils/diagnostics.py | 47 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 baybe/utils/diagnostics.py diff --git a/baybe/_optional/shap.py b/baybe/_optional/shap.py index 0d7916df9..df8130879 100644 --- a/baybe/_optional/shap.py +++ b/baybe/_optional/shap.py @@ -6,9 +6,8 @@ import shap except ModuleNotFoundError as ex: raise OptionalImportError( - "Feature importance ranking functionality is unavailable " - "because 'shap' is not installed. " - "Consider installing BayBE with 'polars' dependency, e.g. via " + "Shapley functionality is unavailable because 'shap' is not installed. " + "Consider installing BayBE with 'shap' dependency, e.g. via " "`pip install baybe[shap]`." 
) from ex diff --git a/baybe/utils/diagnostics.py b/baybe/utils/diagnostics.py new file mode 100644 index 000000000..5c2eea777 --- /dev/null +++ b/baybe/utils/diagnostics.py @@ -0,0 +1,47 @@ +"""Diagnostics utilities.""" + +import pandas as pd +import shap +import torch + +from baybe import Campaign + + +def shapley_values( + campaign: Campaign, explainer: callable = shap.KernelExplainer +) -> shap.Explanation: + """Compute the Shapley values for the provided campaign and data. + + Args: + campaign: The campaign to be explained. + explainer: The explainer to be used. Default is shap.KernelExplainer. + + Returns: + The Shapley values for the provided campaign and data. + + Raises: + ValueError: If no measurements have been provided yet. + """ + if campaign.measurements.empty: + raise ValueError("No measurements have been provided yet.") + + data = campaign.measurements[[p.name for p in campaign.parameters]] + + def model(x): + df = pd.DataFrame(x, columns=data.columns) + + output = campaign.get_surrogate().posterior(df).mean + + if isinstance(output, torch.Tensor): + return output.detach().numpy() + + return output + + explain = explainer(model, data) + shap_values = explain(data) + return shap_values[:, :, 0] + + +def shapley_plot_beeswarm(explaination: shap.Explainer) -> None: + """Plot the Shapley values using a beeswarm plot.""" + shap.plots.beeswarm(explaination) From 2597fd4d731e83a3b577b343d643014957fd0123 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Tue, 1 Oct 2024 23:49:04 +0200 Subject: [PATCH 03/92] Implementation option to perform SHAP either in computational or experimental searchspace representation. 
--- baybe/utils/diagnostics.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/baybe/utils/diagnostics.py b/baybe/utils/diagnostics.py index 5c2eea777..1bb307aa4 100644 --- a/baybe/utils/diagnostics.py +++ b/baybe/utils/diagnostics.py @@ -5,16 +5,22 @@ import torch from baybe import Campaign +from baybe.utils.dataframe import to_tensor def shapley_values( - campaign: Campaign, explainer: callable = shap.KernelExplainer + campaign: Campaign, + explainer: callable = shap.KernelExplainer, + computational_representation: bool = False, ) -> shap.Explanation: """Compute the Shapley values for the provided campaign and data. Args: campaign: The campaign to be explained. explainer: The explainer to be used. Default is shap.KernelExplainer. + computational_representation: Whether to compute the Shapley values + in computational or experimental searchspace. + Default is False. Returns: The Shapley values for the provided campaign and data. @@ -27,15 +33,31 @@ def shapley_values( data = campaign.measurements[[p.name for p in campaign.parameters]] - def model(x): - df = pd.DataFrame(x, columns=data.columns) + if computational_representation: + data = campaign.searchspace.transform(data) - output = campaign.get_surrogate().posterior(df).mean + def model(x): + df = pd.DataFrame(x, columns=data.columns) - if isinstance(output, torch.Tensor): - return output.detach().numpy() + tensor = to_tensor(df) - return output + output = campaign.get_surrogate()._posterior_comp(tensor).mean + + if isinstance(output, torch.Tensor): + return output.detach().numpy() + + return output + else: + + def model(x): + df = pd.DataFrame(x, columns=data.columns) + + output = campaign.get_surrogate().posterior(df).mean + + if isinstance(output, torch.Tensor): + return output.detach().numpy() + + return output explain = explainer(model, data) shap_values = explain(data) From bc1203e1a89a9d5dc767b4e3c6ef9c5972b68411 Mon Sep 17 00:00:00 2001 From: 
"Wieczorek, Alexander" Date: Thu, 3 Oct 2024 17:14:30 +0200 Subject: [PATCH 04/92] SHAP package implementation in diagnostics utility, complete tests and plotting methods. From b348b46c045327f84f9bb9eef312d1227e1f7c59 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Thu, 3 Oct 2024 17:29:51 +0200 Subject: [PATCH 05/92] Tests for explainer utilities and generalization for all explainers in SHAP package. --- tests/utils/test_diagnostics.py | 95 +++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/utils/test_diagnostics.py diff --git a/tests/utils/test_diagnostics.py b/tests/utils/test_diagnostics.py new file mode 100644 index 000000000..40ce6b446 --- /dev/null +++ b/tests/utils/test_diagnostics.py @@ -0,0 +1,95 @@ +"""Tests for diagnostic utilities.""" + +import pandas as pd +import pytest +import shap + +import baybe.utils.diagnostics as diag +from baybe import Campaign +from baybe.objective import SingleTargetObjective +from baybe.parameters import ( + NumericalContinuousParameter, + NumericalDiscreteParameter, + SubstanceParameter, +) +from baybe.searchspace import SearchSpace +from baybe.targets import NumericalTarget + + +def test_shapley_values_no_measurements(): + """A campaign without measurements raises an error.""" + parameters = [NumericalContinuousParameter("param1", bounds=(0, 1))] + searchspace = SearchSpace.from_product(parameters=parameters) + target = NumericalTarget(name="y_1", mode="MAX") + objective = SingleTargetObjective(target) + campaign = Campaign(searchspace, objective) + + with pytest.raises(ValueError, match="No measurements have been provided yet."): + diag.explain(campaign) + + +def test_shapley_with_measurements(): + """Test the explain functionalities with measurements for a hybrid space.""" + parameters = [ + NumericalDiscreteParameter("NumDisc", values=(0, 1, 2)), + NumericalContinuousParameter("NumCont", bounds=(2, 3)), + SubstanceParameter( + name="Molecules", + data={ + "TIPS-TAP": 
"C12=CC=CC=C1N=C3C(C=C(N=C(C=CC=C4)C4=N5)C5=C3)=N2", + "Pyrene": "C1(C=CC2)=C(C2=CC=C3CC=C4)C3=C4C=C1", + }, + encoding="MORDRED", + ), + ] + searchspace = SearchSpace.from_product(parameters=parameters) + target = NumericalTarget(name="y_1", mode="MAX") + objective = SingleTargetObjective(target=target) + campaign = Campaign(searchspace, objective) + + campaign.add_measurements( + pd.DataFrame( + { + "NumDisc": [0, 2], + "NumCont": [2.2, 2.8], + "Molecules": ["Pyrene", "TIPS-TAP"], + "y_1": [0.5, 0.7], + } + ) + ) + campaign.recommend(3) + + """Test the default explainer in experimental representation.""" + shap_val = diag.explain(campaign) + assert isinstance(shap_val, shap.Explanation) + + """Test the default explainer in computational representation.""" + shap_val_comp = diag.explain(campaign, computational_representation=True) + assert isinstance(shap_val_comp, shap.Explanation) + + """Test the MAPLE explainer in experimental representation.""" + maple_explainer = diag.explainer( + campaign, + computational_representation=True, + explainer=shap.explainers.other.Maple, + ) + assert isinstance(maple_explainer, shap.explainers.other._maple.Maple) + + """Ensure that an error is raised if the data + to be explained has a different number of parameters.""" + df = pd.DataFrame( + { + "NumDisc": [0, 2], + "NumCont": [2.2, 2.8], + "Molecules": ["Pyrene", "TIPS-TAP"], + "ExtraParam": [0, 1], + } + ) + with pytest.raises( + ValueError, + match=( + "The provided data does not have the same " + "amount of parameters as the shap explainer background." + ), + ): + diag.explain(campaign, data=df) From ae20322bc10f06b5c2e202d591ce4e4af04b4970 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Thu, 3 Oct 2024 23:44:02 +0200 Subject: [PATCH 06/92] Implemented plotting with non-shap attributions. 
--- baybe/utils/diagnostics.py | 173 +++- botorch_analytical_dark.svg | 1360 +++++++++++++++++++++++++++++++ botorch_analytical_light.svg | 1360 +++++++++++++++++++++++++++++++ pytest.ini | 2 + tests/utils/test_diagnostics.py | 27 +- 5 files changed, 2891 insertions(+), 31 deletions(-) create mode 100644 botorch_analytical_dark.svg create mode 100644 botorch_analytical_light.svg diff --git a/baybe/utils/diagnostics.py b/baybe/utils/diagnostics.py index 1bb307aa4..a709978af 100644 --- a/baybe/utils/diagnostics.py +++ b/baybe/utils/diagnostics.py @@ -1,29 +1,34 @@ """Diagnostics utilities.""" +import numbers +import warnings + +import numpy as np import pandas as pd import shap -import torch from baybe import Campaign from baybe.utils.dataframe import to_tensor -def shapley_values( +def explainer( campaign: Campaign, - explainer: callable = shap.KernelExplainer, + explainer_class: shap.Explainer = shap.KernelExplainer, computational_representation: bool = False, -) -> shap.Explanation: - """Compute the Shapley values for the provided campaign and data. + **kwargs, +) -> shap.Explainer: + """Create an explainer for the provided campaign. Args: campaign: The campaign to be explained. - explainer: The explainer to be used. Default is shap.KernelExplainer. + explainer_class: The explainer to be used. Default is shap.KernelExplainer. computational_representation: Whether to compute the Shapley values in computational or experimental searchspace. Default is False. + **kwargs: Additional keyword arguments to be passed to the explainer. Returns: - The Shapley values for the provided campaign and data. + The explainer for the provided campaign. Raises: ValueError: If no measurements have been provided yet. 
@@ -31,39 +36,157 @@ def shapley_values( if campaign.measurements.empty: raise ValueError("No measurements have been provided yet.") - data = campaign.measurements[[p.name for p in campaign.parameters]] + data = campaign.measurements[[p.name for p in campaign.parameters]].copy() if computational_representation: data = campaign.searchspace.transform(data) def model(x): - df = pd.DataFrame(x, columns=data.columns) - - tensor = to_tensor(df) - + tensor = to_tensor(x) output = campaign.get_surrogate()._posterior_comp(tensor).mean - if isinstance(output, torch.Tensor): - return output.detach().numpy() - - return output + return output.detach().numpy() else: def model(x): df = pd.DataFrame(x, columns=data.columns) - output = campaign.get_surrogate().posterior(df).mean - if isinstance(output, torch.Tensor): - return output.detach().numpy() + return output.detach().numpy() + + shap_explainer = explainer_class(model, data, **kwargs) + return shap_explainer + + +def explanation( + campaign: Campaign, + data: np.ndarray = None, + explainer_class: shap.Explainer = shap.KernelExplainer, + computational_representation: bool = False, + **kwargs, +) -> shap.Explanation: + """Compute the Shapley values for the provided campaign and data. + + Args: + campaign: The campaign to be explained. + data: The data to be explained. + Default is None which uses the campaign's measurements. + explainer_class: The explainer to be used. + Default is shap.KernelExplainer. + computational_representation: Whether to compute the Shapley values + in computational or experimental searchspace. + Default is False. + **kwargs: Additional keyword arguments to be passed to the explainer. + + Returns: + The Shapley values for the provided campaign. + + Raises: + ValueError: If the provided data does not have the same amount of parameters + as previously provided to the explainer. + """ + is_shap_explainer = not explainer_class.__module__.startswith( + "shap.explainers.other." 
+ ) + + if not is_shap_explainer and not computational_representation: + raise ValueError( + "Experimental representation is not " + "supported for non-Kernel SHAP explainer." + ) + + explainer_obj = explainer( + campaign, + explainer_class=explainer_class, + computational_representation=computational_representation, + **kwargs, + ) + + if data is None: + if isinstance(explainer_obj.data, np.ndarray): + data = explainer_obj.data + else: + data = explainer_obj.data.data + elif computational_representation: + data = campaign.searchspace.transform(data) - return output + if not is_shap_explainer: + """Return attributions for non-SHAP explainers.""" + attributions = explainer_obj.attributions(data)[0] + if computational_representation: + feature_names = campaign.searchspace.comp_rep_columns + else: + feature_names = campaign.searchspace.parameter_names + explanations = shap.Explanation( + values=attributions, + base_values=data, + data=data, + ) + explanations.feature_names = list(feature_names) + return explanations + + if data.shape[1] != explainer_obj.data.data.shape[1]: + raise ValueError( + "The provided data does not have the same amount " + "of parameters as the shap explainer background." 
+ ) + else: + shap_explanations = explainer_obj(data)[:, :, 0] - explain = explainer(model, data) - shap_values = explain(data) - return shap_values[:, :, 0] + return shap_explanations -def shapley_plot_beeswarm(explaination: shap.Explainer) -> None: +def shap_plot_beeswarm(explanation: shap.Explanation, **kwargs) -> None: """Plot the Shapley values using a beeswarm plot.""" - shap.plots.beeswarm(explaination) + shap.plots.beeswarm(explanation, **kwargs) + + +def shap_plot_waterfall(explanation: shap.Explanation, **kwargs) -> None: + """Plot the Shapley values using a waterfall plot.""" + shap.plots.waterfall(explanation, **kwargs) + + +def shap_plot_bar(explanation: shap.Explanation, **kwargs) -> None: + """Plot the Shapley values using a bar plot.""" + shap.plots.bar(explanation, **kwargs) + + +def shap_plot_scatter(explanation: shap.Explanation | memoryview, **kwargs) -> None: + """Plot the Shapley values using a scatter plot while leaving out string values. + + Args: + explanation: The Shapley values to be plotted. + **kwargs: Additional keyword arguments to be passed to the scatter plot. + + Raises: + ValueError: If the provided explanation object does not match the + required types. + """ + if isinstance(explanation, memoryview): + data = explanation.obj + elif isinstance(explanation, shap.Explanation): + data = explanation.data.data.obj + else: + raise ValueError("The provided explanation argument is not of a valid type.") + + def is_not_numeric_column(col): + return np.array([not isinstance(v, numbers.Number) for v in col]).any() + + if data.ndim == 1: + if is_not_numeric_column(data): + warnings.warn( + "Cannot plot scatter plot for the provided " + "explanation as it contains non-numeric values." 
+ ) + else: + shap.plots.scatter(explanation, **kwargs) + else: + for i in range(data.shape[1]): + if is_not_numeric_column(data[:, i]): + warnings.warn( + "Cannot plot scatter plot for column " + f"'{explanation.feature_names[i]}' " + "as it contains non-numeric values." + ) + else: + shap.plots.scatter(explanation[:, i], **kwargs) diff --git a/botorch_analytical_dark.svg b/botorch_analytical_dark.svg new file mode 100644 index 000000000..175e82da0 --- /dev/null +++ b/botorch_analytical_dark.svg @@ -0,0 +1,1360 @@ + + + + + + + + 2024-10-03T11:44:51.512451 + image/svg+xml + + + Matplotlib v3.9.1, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/botorch_analytical_light.svg b/botorch_analytical_light.svg new file mode 100644 index 000000000..5ccc90582 --- /dev/null +++ b/botorch_analytical_light.svg @@ -0,0 +1,1360 @@ + + + + + + + + 2024-10-03T11:44:51.586128 + image/svg+xml + + + Matplotlib v3.9.1, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/pytest.ini b/pytest.ini index c993cc465..283221189 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,6 +10,8 @@ addopts = --ignore=baybe/_optional --ignore=baybe/utils/chemistry.py --ignore=tests/simulate_telemetry.py + --ignore=baybe/utils/diagnostics.py + --ignore=tests/utils/test_diagnostics.py testpaths = baybe tests \ No newline at end of file diff --git a/tests/utils/test_diagnostics.py b/tests/utils/test_diagnostics.py index 40ce6b446..7636f38b3 100644 --- a/tests/utils/test_diagnostics.py +++ b/tests/utils/test_diagnostics.py @@ -25,7 +25,7 @@ def test_shapley_values_no_measurements(): campaign = Campaign(searchspace, objective) with pytest.raises(ValueError, match="No measurements have been provided yet."): - diag.explain(campaign) + diag.explanation(campaign) def test_shapley_with_measurements(): @@ -60,21 +60,36 @@ def test_shapley_with_measurements(): campaign.recommend(3) """Test the default explainer in experimental representation.""" - shap_val = diag.explain(campaign) + shap_val = diag.explanation(campaign) assert isinstance(shap_val, shap.Explanation) """Test the default explainer in computational representation.""" - 
shap_val_comp = diag.explain(campaign, computational_representation=True) + shap_val_comp = diag.explanation(campaign, computational_representation=True) assert isinstance(shap_val_comp, shap.Explanation) - """Test the MAPLE explainer in experimental representation.""" + """Test the MAPLE explainer in computational representation.""" maple_explainer = diag.explainer( campaign, computational_representation=True, - explainer=shap.explainers.other.Maple, + explainer_class=shap.explainers.other.Maple, ) assert isinstance(maple_explainer, shap.explainers.other._maple.Maple) + """Ensure that an error is raised if non-computational representation + is used with a non-Kernel SHAP explainer.""" + with pytest.raises( + ValueError, + match=( + "Experimental representation is not supported " + "for non-Kernel SHAP explainer." + ), + ): + diag.explanation( + campaign, + computational_representation=False, + explainer_class=shap.explainers.other.Maple, + ) + """Ensure that an error is raised if the data to be explained has a different number of parameters.""" df = pd.DataFrame( @@ -92,4 +107,4 @@ def test_shapley_with_measurements(): "amount of parameters as the shap explainer background." ), ): - diag.explain(campaign, data=df) + diag.explanation(campaign, data=df) From de9d1e9e21685f5c020a759d510ca30d7eab8532 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Fri, 4 Oct 2024 13:40:27 +0200 Subject: [PATCH 07/92] Refactored diagnostics test and optimized handling of maple explainers. 
--- baybe/utils/diagnostics.py | 22 +++---- tests/utils/test_diagnostics.py | 101 ++++++++++++++++++-------------- 2 files changed, 69 insertions(+), 54 deletions(-) diff --git a/baybe/utils/diagnostics.py b/baybe/utils/diagnostics.py index a709978af..ab3b6cdd8 100644 --- a/baybe/utils/diagnostics.py +++ b/baybe/utils/diagnostics.py @@ -112,7 +112,11 @@ def explanation( if not is_shap_explainer: """Return attributions for non-SHAP explainers.""" - attributions = explainer_obj.attributions(data)[0] + if explainer_class.__module__.endswith("maple"): + """Aditional argument for maple to increase comparability to SHAP.""" + attributions = explainer_obj.attributions(data, multiply_by_input=True)[0] + else: + attributions = explainer_obj.attributions(data)[0] if computational_representation: feature_names = campaign.searchspace.comp_rep_columns else: @@ -181,12 +185,10 @@ def is_not_numeric_column(col): else: shap.plots.scatter(explanation, **kwargs) else: - for i in range(data.shape[1]): - if is_not_numeric_column(data[:, i]): - warnings.warn( - "Cannot plot scatter plot for column " - f"'{explanation.feature_names[i]}' " - "as it contains non-numeric values." - ) - else: - shap.plots.scatter(explanation[:, i], **kwargs) + number_enum = [i for i, x in enumerate(data[1]) if not isinstance(x, str)] + if len(number_enum) < len(explanation.feature_names): + warnings.warn( + "Cannot plot SHAP scatter plot for all " + "parameters as some contain non-numeric values." 
+ ) + shap.plots.scatter(explanation[:, number_enum], **kwargs) diff --git a/tests/utils/test_diagnostics.py b/tests/utils/test_diagnostics.py index 7636f38b3..06cb8c488 100644 --- a/tests/utils/test_diagnostics.py +++ b/tests/utils/test_diagnostics.py @@ -16,27 +16,16 @@ from baybe.targets import NumericalTarget -def test_shapley_values_no_measurements(): - """A campaign without measurements raises an error.""" - parameters = [NumericalContinuousParameter("param1", bounds=(0, 1))] - searchspace = SearchSpace.from_product(parameters=parameters) - target = NumericalTarget(name="y_1", mode="MAX") - objective = SingleTargetObjective(target) - campaign = Campaign(searchspace, objective) - - with pytest.raises(ValueError, match="No measurements have been provided yet."): - diag.explanation(campaign) - - -def test_shapley_with_measurements(): - """Test the explain functionalities with measurements for a hybrid space.""" +@pytest.fixture +def diagnostics_campaign(): + """Create a campaign with a hybrid space including substances.""" parameters = [ NumericalDiscreteParameter("NumDisc", values=(0, 1, 2)), NumericalContinuousParameter("NumCont", bounds=(2, 3)), SubstanceParameter( name="Molecules", data={ - "TIPS-TAP": "C12=CC=CC=C1N=C3C(C=C(N=C(C=CC=C4)C4=N5)C5=C3)=N2", + "TAP": "C12=CC=CC=C1N=C3C(C=C(N=C(C=CC=C4)C4=N5)C5=C3)=N2", "Pyrene": "C1(C=CC2)=C(C2=CC=C3CC=C4)C3=C4C=C1", }, encoding="MORDRED", @@ -46,49 +35,47 @@ def test_shapley_with_measurements(): target = NumericalTarget(name="y_1", mode="MAX") objective = SingleTargetObjective(target=target) campaign = Campaign(searchspace, objective) + return campaign - campaign.add_measurements( + +@pytest.fixture +def diagnostics_campaign_activated(diagnostics_campaign): + """Create an activated campaign with a hybrid space including substances. + + Measurements were added and first recommendations were made. 
+ """ + diagnostics_campaign.add_measurements( pd.DataFrame( { "NumDisc": [0, 2], "NumCont": [2.2, 2.8], - "Molecules": ["Pyrene", "TIPS-TAP"], + "Molecules": ["Pyrene", "TAP"], "y_1": [0.5, 0.7], } ) ) - campaign.recommend(3) + diagnostics_campaign.recommend(3) + return diagnostics_campaign + + +def test_shapley_values_no_measurements(diagnostics_campaign): + """A campaign without measurements raises an error.""" + with pytest.raises(ValueError, match="No measurements have been provided yet."): + diag.explanation(diagnostics_campaign) + +def test_shapley_with_measurements(diagnostics_campaign_activated): + """Test the explain functionalities with measurements.""" """Test the default explainer in experimental representation.""" - shap_val = diag.explanation(campaign) + shap_val = diag.explanation(diagnostics_campaign_activated) assert isinstance(shap_val, shap.Explanation) """Test the default explainer in computational representation.""" - shap_val_comp = diag.explanation(campaign, computational_representation=True) - assert isinstance(shap_val_comp, shap.Explanation) - - """Test the MAPLE explainer in computational representation.""" - maple_explainer = diag.explainer( - campaign, + shap_val_comp = diag.explanation( + diagnostics_campaign_activated, computational_representation=True, - explainer_class=shap.explainers.other.Maple, ) - assert isinstance(maple_explainer, shap.explainers.other._maple.Maple) - - """Ensure that an error is raised if non-computational representation - is used with a non-Kernel SHAP explainer.""" - with pytest.raises( - ValueError, - match=( - "Experimental representation is not supported " - "for non-Kernel SHAP explainer." 
- ), - ): - diag.explanation( - campaign, - computational_representation=False, - explainer_class=shap.explainers.other.Maple, - ) + assert isinstance(shap_val_comp, shap.Explanation) """Ensure that an error is raised if the data to be explained has a different number of parameters.""" @@ -96,7 +83,7 @@ def test_shapley_with_measurements(): { "NumDisc": [0, 2], "NumCont": [2.2, 2.8], - "Molecules": ["Pyrene", "TIPS-TAP"], + "Molecules": ["Pyrene", "TAP"], "ExtraParam": [0, 1], } ) @@ -107,4 +94,30 @@ def test_shapley_with_measurements(): "amount of parameters as the shap explainer background." ), ): - diag.explanation(campaign, data=df) + diag.explanation(diagnostics_campaign_activated, data=df) + + +def test_non_shapley_explainers(diagnostics_campaign_activated): + """Test the explain functionalities with the non-SHAP explainer MAPLE.""" + """Ensure that an error is raised if non-computational representation + is used with a non-Kernel SHAP explainer.""" + with pytest.raises( + ValueError, + match=( + "Experimental representation is not supported " + "for non-Kernel SHAP explainer." + ), + ): + diag.explanation( + diagnostics_campaign_activated, + computational_representation=False, + explainer_class=shap.explainers.other.Maple, + ) + + """Test the MAPLE explainer in computational representation.""" + maple_explainer = diag.explainer( + diagnostics_campaign_activated, + computational_representation=True, + explainer_class=shap.explainers.other.Maple, + ) + assert isinstance(maple_explainer, shap.explainers.other._maple.Maple) From e183957ddd795735d715cbba42e7256e83bc1d8e Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Fri, 4 Oct 2024 13:47:59 +0200 Subject: [PATCH 08/92] Shortened plotting method names. 
--- baybe/utils/diagnostics.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/baybe/utils/diagnostics.py b/baybe/utils/diagnostics.py index ab3b6cdd8..6386da638 100644 --- a/baybe/utils/diagnostics.py +++ b/baybe/utils/diagnostics.py @@ -113,7 +113,7 @@ def explanation( if not is_shap_explainer: """Return attributions for non-SHAP explainers.""" if explainer_class.__module__.endswith("maple"): - """Aditional argument for maple to increase comparability to SHAP.""" + """Additional argument for maple to increase comparability to SHAP.""" attributions = explainer_obj.attributions(data, multiply_by_input=True)[0] else: attributions = explainer_obj.attributions(data)[0] @@ -140,22 +140,22 @@ def explanation( return shap_explanations -def shap_plot_beeswarm(explanation: shap.Explanation, **kwargs) -> None: +def plot_beeswarm(explanation: shap.Explanation, **kwargs) -> None: """Plot the Shapley values using a beeswarm plot.""" shap.plots.beeswarm(explanation, **kwargs) -def shap_plot_waterfall(explanation: shap.Explanation, **kwargs) -> None: +def plot_waterfall(explanation: shap.Explanation, **kwargs) -> None: """Plot the Shapley values using a waterfall plot.""" shap.plots.waterfall(explanation, **kwargs) -def shap_plot_bar(explanation: shap.Explanation, **kwargs) -> None: +def plot_bar(explanation: shap.Explanation, **kwargs) -> None: """Plot the Shapley values using a bar plot.""" shap.plots.bar(explanation, **kwargs) -def shap_plot_scatter(explanation: shap.Explanation | memoryview, **kwargs) -> None: +def plot_scatter(explanation: shap.Explanation | memoryview, **kwargs) -> None: """Plot the Shapley values using a scatter plot while leaving out string values. 
Args: From c389ac10858e857a905b7f51bca8bcd7c80d45c4 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Fri, 4 Oct 2024 14:23:34 +0200 Subject: [PATCH 09/92] Cleanup for PR --- botorch_analytical_dark.svg | 1360 ---------------------------------- botorch_analytical_light.svg | 1360 ---------------------------------- 2 files changed, 2720 deletions(-) delete mode 100644 botorch_analytical_dark.svg delete mode 100644 botorch_analytical_light.svg diff --git a/botorch_analytical_dark.svg b/botorch_analytical_dark.svg deleted file mode 100644 index 175e82da0..000000000 --- a/botorch_analytical_dark.svg +++ /dev/null @@ -1,1360 +0,0 @@ - - - - - - - - 2024-10-03T11:44:51.512451 - image/svg+xml - - - Matplotlib v3.9.1, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/botorch_analytical_light.svg b/botorch_analytical_light.svg deleted file mode 100644 index 5ccc90582..000000000 --- a/botorch_analytical_light.svg +++ /dev/null @@ -1,1360 +0,0 @@ - - - - - - - - 2024-10-03T11:44:51.586128 - image/svg+xml - - - Matplotlib v3.9.1, 
https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - From 55e723c012040bc3cbceba78516fc52c4dd56a9b Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Wed, 23 Oct 2024 22:57:51 +0200 Subject: [PATCH 10/92] Renamed diangostics package, enabled optional shap import --- README.md | 2 +- baybe/_optional/diagnostics.py | 16 ++++++++++++++++ baybe/_optional/info.py | 2 +- baybe/_optional/shap.py | 16 ---------------- baybe/utils/diagnostics.py | 2 +- mypy.ini | 2 +- pyproject.toml | 2 +- tox.ini | 2 +- 8 files changed, 22 insertions(+), 22 deletions(-) create mode 100644 baybe/_optional/diagnostics.py delete mode 100644 baybe/_optional/shap.py diff --git a/README.md b/README.md index 562147ac5..e9060112c 100644 --- a/README.md +++ b/README.md @@ -296,7 +296,7 @@ The available groups are: - `mypy`: Required for static type checking. - `onnx`: Required for using custom surrogate models in [ONNX format](https://onnx.ai). 
- `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/) -- `shap`: Required for feature importance ranking via [SHAP](https://shap.readthedocs.io/) +- `diagnostics`: Required for feature importance ranking via [SHAP](https://shap.readthedocs.io/) - `simulation`: Enabling the [simulation](https://emdgroup.github.io/baybe/stable/_autosummary/baybe.simulation.html) module. - `test`: Required for running the tests. - `dev`: All of the above plus `tox` and `pip-audit`. For code contributors. diff --git a/baybe/_optional/diagnostics.py b/baybe/_optional/diagnostics.py new file mode 100644 index 000000000..08138f657 --- /dev/null +++ b/baybe/_optional/diagnostics.py @@ -0,0 +1,16 @@ +"""Optional import for diagnostics utilities.""" + +from baybe.exceptions import OptionalImportError + +try: + import shap +except ModuleNotFoundError as ex: + raise OptionalImportError( + "Explainer functionality is unavailable because 'diagnostics' is not installed." + " Consider installing BayBE with 'diagnostics' dependency, e.g. via " + "`pip install baybe[diagnostics]`." 
+ ) from ex + +__all__ = [ + "shap", +] diff --git a/baybe/_optional/info.py b/baybe/_optional/info.py index 45efd69a0..33b189130 100644 --- a/baybe/_optional/info.py +++ b/baybe/_optional/info.py @@ -28,7 +28,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404 MORDRED_INSTALLED = find_spec("mordred") is not None ONNX_INSTALLED = find_spec("onnxruntime") is not None POLARS_INSTALLED = find_spec("polars") is not None - SHAP_INSTALLED = find_spec("shap") is not None + DIAGNOSTICS_INSTALLED = find_spec("diagnostics") is not None PRE_COMMIT_INSTALLED = find_spec("pre_commit") is not None PYDOCLINT_INSTALLED = find_spec("pydoclint") is not None RDKIT_INSTALLED = find_spec("rdkit") is not None diff --git a/baybe/_optional/shap.py b/baybe/_optional/shap.py deleted file mode 100644 index df8130879..000000000 --- a/baybe/_optional/shap.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Optional SHAP import.""" - -from baybe.exceptions import OptionalImportError - -try: - import shap -except ModuleNotFoundError as ex: - raise OptionalImportError( - "Shapley functionality is unavailable because 'shap' is not installed. " - "Consider installing BayBE with 'shap' dependency, e.g. via " - "`pip install baybe[shap]`." 
- ) from ex - -__all__ = [ - "shap", -] diff --git a/baybe/utils/diagnostics.py b/baybe/utils/diagnostics.py index 6386da638..a1e1a4e62 100644 --- a/baybe/utils/diagnostics.py +++ b/baybe/utils/diagnostics.py @@ -5,9 +5,9 @@ import numpy as np import pandas as pd -import shap from baybe import Campaign +from baybe._optional.diagnostics import shap from baybe.utils.dataframe import to_tensor diff --git a/mypy.ini b/mypy.ini index c285a4e8c..588700def 100644 --- a/mypy.ini +++ b/mypy.ini @@ -63,5 +63,5 @@ ignore_missing_imports = True [mypy-polars] ignore_missing_imports = True -[mypy-shap] +[mypy-diagnostics] ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index ced3e43c2..2e03e2af2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,7 +138,7 @@ polars = [ "polars[pyarrow]>=0.19.19,<2", ] -shap = [ +diagnostics = [ "shap>=0.46.0", ] diff --git a/tox.ini b/tox.ini index 79ab98ec1..1f90161de 100644 --- a/tox.ini +++ b/tox.ini @@ -5,7 +5,7 @@ isolated_build = True [testenv:fulltest,fulltest-py{310,311,312}] description = Run PyTest with all extra functionality -extras = chem,examples,lint,onnx,polars,shap,simulation,test +extras = chem,examples,lint,onnx,polars,diagnostics,simulation,test passenv = CI BAYBE_NUMPY_USE_SINGLE_PRECISION From ee57008d70b3cf175d344c1c47abb80861d06c6c Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Sun, 27 Oct 2024 21:02:41 +0100 Subject: [PATCH 11/92] Refactoring of test_diagnostics.py --- CHANGELOG.md | 4 + CONTRIBUTORS.md | 4 +- tests/utils/test_diagnostics.py | 186 +++++++++++++++----------------- 3 files changed, 95 insertions(+), 99 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fbf6bddd..cb238f50b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. 
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] +### Added +- Added SHAP analysis within the new `diagnostics` package. + ## [0.11.1] - 2024-10-01 ### Added - Continuous linear constraints have been consolidated in the new diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 05e9796ea..140817a83 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -25,4 +25,6 @@ - Di Jin (Merck Life Science KGaA, Darmstadt, Germany):\ Cardinality constraints - Julian Streibel (Merck Life Science KGaA, Darmstadt, Germany):\ - Bernoulli multi-armed bandit and Thompson sampling \ No newline at end of file + Bernoulli multi-armed bandit and Thompson sampling +- Alexander Wieczorek (Swiss Federal Institute for Materials Science and Technology, Dรผbendorf, Switzerland):\ + SHAP explainers for diagnoatics \ No newline at end of file diff --git a/tests/utils/test_diagnostics.py b/tests/utils/test_diagnostics.py index 06cb8c488..eed039371 100644 --- a/tests/utils/test_diagnostics.py +++ b/tests/utils/test_diagnostics.py @@ -1,92 +1,72 @@ """Tests for diagnostic utilities.""" +import inspect + import pandas as pd import pytest -import shap import baybe.utils.diagnostics as diag -from baybe import Campaign -from baybe.objective import SingleTargetObjective -from baybe.parameters import ( - NumericalContinuousParameter, - NumericalDiscreteParameter, - SubstanceParameter, +from baybe._optional.diagnostics import shap +from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender +from baybe.recommenders.pure.bayesian.base import BayesianRecommender +from baybe.searchspace import SearchSpaceType +from baybe.utils.basic import get_subclasses +from tests.conftest import run_iterations + + +def has_required_init_parameters(cls): + """Helpfer function checks if initializer has required standard parameters.""" + required_parameters = ["self", "model", 
"data"] + init_signature = inspect.signature(cls.__init__) + parameters = list(init_signature.parameters.keys()) + return parameters[:3] == required_parameters + + +valid_explainers = [ + getattr(shap.explainers.other, cls_name) + for cls_name in shap.explainers.other.__all__ + if has_required_init_parameters(getattr(shap.explainers.other, cls_name)) +] + +valid_hybrid_bayesian_recommenders = [ + TwoPhaseMetaRecommender(recommender=cls()) + for cls in get_subclasses(BayesianRecommender) + if cls.compatibility == SearchSpaceType.HYBRID +] + + +@pytest.mark.parametrize( + "parameter_names", + [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], ) -from baybe.searchspace import SearchSpace -from baybe.targets import NumericalTarget - - -@pytest.fixture -def diagnostics_campaign(): - """Create a campaign with a hybrid space including substances.""" - parameters = [ - NumericalDiscreteParameter("NumDisc", values=(0, 1, 2)), - NumericalContinuousParameter("NumCont", bounds=(2, 3)), - SubstanceParameter( - name="Molecules", - data={ - "TAP": "C12=CC=CC=C1N=C3C(C=C(N=C(C=CC=C4)C4=N5)C5=C3)=N2", - "Pyrene": "C1(C=CC2)=C(C2=CC=C3CC=C4)C3=C4C=C1", - }, - encoding="MORDRED", - ), - ] - searchspace = SearchSpace.from_product(parameters=parameters) - target = NumericalTarget(name="y_1", mode="MAX") - objective = SingleTargetObjective(target=target) - campaign = Campaign(searchspace, objective) - return campaign - - -@pytest.fixture -def diagnostics_campaign_activated(diagnostics_campaign): - """Create an activated campaign with a hybrid space including substances. - - Measurements were added and first recommendations were made. 
- """ - diagnostics_campaign.add_measurements( - pd.DataFrame( - { - "NumDisc": [0, 2], - "NumCont": [2.2, 2.8], - "Molecules": ["Pyrene", "TAP"], - "y_1": [0.5, 0.7], - } - ) - ) - diagnostics_campaign.recommend(3) - return diagnostics_campaign - - -def test_shapley_values_no_measurements(diagnostics_campaign): +def test_shapley_values_no_measurements(campaign): """A campaign without measurements raises an error.""" with pytest.raises(ValueError, match="No measurements have been provided yet."): - diag.explanation(diagnostics_campaign) + diag.explanation(campaign) -def test_shapley_with_measurements(diagnostics_campaign_activated): +@pytest.mark.slow +@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize( + "parameter_names", + [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], +) +def test_shapley_with_measurements(campaign): """Test the explain functionalities with measurements.""" - """Test the default explainer in experimental representation.""" - shap_val = diag.explanation(diagnostics_campaign_activated) - assert isinstance(shap_val, shap.Explanation) - - """Test the default explainer in computational representation.""" - shap_val_comp = diag.explanation( - diagnostics_campaign_activated, - computational_representation=True, - ) - assert isinstance(shap_val_comp, shap.Explanation) + """Test the default explainer in experimental + and computational representations.""" + run_iterations(campaign, n_iterations=2, batch_size=1) + + for computational_representation in [False, True]: + shap_val = diag.explanation( + campaign, + computational_representation=computational_representation, + ) + assert isinstance(shap_val, shap.Explanation) """Ensure that an error is raised if the data to be explained has a different number of parameters.""" - df = pd.DataFrame( - { - "NumDisc": [0, 2], - "NumCont": [2.2, 2.8], - "Molecules": ["Pyrene", "TAP"], - "ExtraParam": [0, 1], - } - ) + df = 
pd.DataFrame({"Num_disc_1": [0, 2]}) with pytest.raises( ValueError, match=( @@ -94,30 +74,40 @@ def test_shapley_with_measurements(diagnostics_campaign_activated): "amount of parameters as the shap explainer background." ), ): - diag.explanation(diagnostics_campaign_activated, data=df) + diag.explanation(campaign, data=df) -def test_non_shapley_explainers(diagnostics_campaign_activated): +@pytest.mark.parametrize( + "parameter_names", + [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], +) +def test_non_shapley_explainers(campaign): """Test the explain functionalities with the non-SHAP explainer MAPLE.""" - """Ensure that an error is raised if non-computational representation - is used with a non-Kernel SHAP explainer.""" - with pytest.raises( - ValueError, - match=( - "Experimental representation is not supported " - "for non-Kernel SHAP explainer." - ), - ): - diag.explanation( - diagnostics_campaign_activated, - computational_representation=False, - explainer_class=shap.explainers.other.Maple, - ) - - """Test the MAPLE explainer in computational representation.""" - maple_explainer = diag.explainer( - diagnostics_campaign_activated, - computational_representation=True, - explainer_class=shap.explainers.other.Maple, - ) - assert isinstance(maple_explainer, shap.explainers.other._maple.Maple) + run_iterations(campaign, n_iterations=2, batch_size=1) + + for explainer_cls in valid_explainers: + try: + """Ensure that an error is raised if non-computational representation + is used with a non-Kernel SHAP explainer.""" + with pytest.raises( + ValueError, + match=( + "Experimental representation is not supported " + "for non-Kernel SHAP explainer." 
+ ), + ): + diag.explanation( + campaign, + computational_representation=False, + explainer_class=explainer_cls, + ) + + """Test the non-SHAP explainer in computational representation.""" + other_explainer = diag.explanation( + campaign, + computational_representation=True, + explainer_class=explainer_cls, + ) + assert isinstance(other_explainer, shap.Explanation) + except NotImplementedError: + pass From 103a5f7360f465533dba948fcaa50185e4c8b00e Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Mon, 28 Oct 2024 15:12:25 +0100 Subject: [PATCH 12/92] Fixed changelog merging error --- CHANGELOG.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62feef412..6818339da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added - Added SHAP analysis within the new `diagnostics` package. +- `allow_missing` and `allow_extra` keyword arguments to `Objective.transform` + +### Deprecations +- Passing a dataframe via the `data` argument to `Objective.transform` is no longer + possible. The dataframe must now be passed as positional argument. +- The new `allow_extra` flag is automatically set to `True` in `Objective.transform` + when left unspecified +- `get_transform_parameters` has been replaced with `get_transform_objects` +- Passing a dataframe via the `data` argument to `Target.transform` is no longer + possible. The data must now be passed as a series as first positional argument. 
+ +## [0.11.2] - 2024-10-11 +### Added +- `n_restarts` and `n_raw_samples` keywords to configure continuous optimization + behavior for `BotorchRecommender` +- User guide for utilities +- `mypy` rule expecting explicit `override` markers for method overrides + +### Changed +- Utility `add_fake_results` renamed to `add_fake_measurements` +- Utilities `add_fake_measurements` and `add_parameter_noise` now also return the + dataframe they modified in-place + +### Fixed +- Leftover attrs-decorated classes are garbage collected before the subclass tree is + traversed, avoiding sporadic serialization problems ## [0.11.1] - 2024-10-01 ### Added From ffda991626ccd7f766ee5a1cfed4fd5f4569a457 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 1 Nov 2024 17:06:48 +0100 Subject: [PATCH 13/92] Update pyproject.toml --- pyproject.toml | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index adfc790bf..29e79894a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ onnx = [ dev = [ "baybe[chem]", + "baybe[diagnostics]", "baybe[docs]", "baybe[examples]", "baybe[lint]", @@ -94,6 +95,11 @@ dev = [ "uv>=0.3.0", # `uv lock` (for lockfiles) is stable since 0.3.0: https://github.com/astral-sh/uv/issues/2679#event-13950215962 ] +diagnostics = [ + "shap>=0.46.0", + "lime>=0.2.0.1" +] + docs = [ "baybe[examples]", # docs cannot be built without running examples "furo>=2023.09.10", @@ -138,10 +144,6 @@ polars = [ "polars[pyarrow]>=0.19.19,<2", ] -diagnostics = [ - "shap>=0.46.0", -] - simulation = [ "xyzpy>=1.2.1", ] From eaa5c38add49e0468bef45dae2c5285f247abb34 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 1 Nov 2024 19:07:35 +0100 Subject: [PATCH 14/92] Rework import flag --- baybe/_optional/info.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/baybe/_optional/info.py b/baybe/_optional/info.py index 33b189130..db7518aca 100644 --- a/baybe/_optional/info.py +++ 
b/baybe/_optional/info.py @@ -28,7 +28,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404 MORDRED_INSTALLED = find_spec("mordred") is not None ONNX_INSTALLED = find_spec("onnxruntime") is not None POLARS_INSTALLED = find_spec("polars") is not None - DIAGNOSTICS_INSTALLED = find_spec("diagnostics") is not None + SHAP_INSTALLED = find_spec("shap") is not None PRE_COMMIT_INSTALLED = find_spec("pre_commit") is not None PYDOCLINT_INSTALLED = find_spec("pydoclint") is not None RDKIT_INSTALLED = find_spec("rdkit") is not None @@ -46,6 +46,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404 # Package combinations CHEM_INSTALLED = MORDRED_INSTALLED and RDKIT_INSTALLED +DIAGNOSTICS_INSTALLED = SHAP_INSTALLED LINT_INSTALLED = all( ( FLAKE8_INSTALLED, From 4ca9ffd66188e961a67cc2a62fe9a0402e63afa8 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 1 Nov 2024 19:11:11 +0100 Subject: [PATCH 15/92] Update mypy.ini --- mypy.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index 1188e9e6e..769dcb3dd 100644 --- a/mypy.ini +++ b/mypy.ini @@ -65,5 +65,5 @@ ignore_missing_imports = True [mypy-polars] ignore_missing_imports = True -[mypy-diagnostics] +[mypy-shap.*] ignore_missing_imports = True From 9fddbdd88fcc57f68a9009a9741dbfd8d47c5ab3 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 1 Nov 2024 20:27:41 +0100 Subject: [PATCH 16/92] Rework tests --- tests/utils/test_diagnostics.py | 92 ++++++++++++++++++++------------- 1 file changed, 55 insertions(+), 37 deletions(-) diff --git a/tests/utils/test_diagnostics.py b/tests/utils/test_diagnostics.py index eed039371..4b2f016f3 100644 --- a/tests/utils/test_diagnostics.py +++ b/tests/utils/test_diagnostics.py @@ -4,6 +4,7 @@ import pandas as pd import pytest +from pytest import param import baybe.utils.diagnostics as diag from baybe._optional.diagnostics import shap @@ -14,22 +15,35 @@ from tests.conftest import run_iterations -def 
has_required_init_parameters(cls): - """Helpfer function checks if initializer has required standard parameters.""" +def _has_required_init_parameters(cls): + """Helper function checks if initializer has required standard parameters.""" required_parameters = ["self", "model", "data"] init_signature = inspect.signature(cls.__init__) parameters = list(init_signature.parameters.keys()) return parameters[:3] == required_parameters -valid_explainers = [ - getattr(shap.explainers.other, cls_name) +non_shap_explainers = [ + param(explainer, id=f"{explainer.__name__}") for cls_name in shap.explainers.other.__all__ - if has_required_init_parameters(getattr(shap.explainers.other, cls_name)) + if _has_required_init_parameters( + explainer := getattr(shap.explainers.other, cls_name) + ) +] + + +shap_explainers = [ + param(explainer, id=f"{explainer.__name__}") + for cls_name in shap.explainers.__all__ + if _has_required_init_parameters(explainer := getattr(shap.explainers, cls_name)) + and all( + x not in explainer.__name__ + for x in ["Tree", "GPU", "Gradient", "Sampling", "Deep"] + ) ] valid_hybrid_bayesian_recommenders = [ - TwoPhaseMetaRecommender(recommender=cls()) + param(TwoPhaseMetaRecommender(recommender=cls()), id=f"{cls.__name__}") for cls in get_subclasses(BayesianRecommender) if cls.compatibility == SearchSpaceType.HYBRID ] @@ -46,23 +60,26 @@ def test_shapley_values_no_measurements(campaign): @pytest.mark.slow +@pytest.mark.parametrize("explainer_cls", shap_explainers) +@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) @pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) @pytest.mark.parametrize( "parameter_names", [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], + ids=["params1"], ) -def test_shapley_with_measurements(campaign): +def test_shapley_with_measurements(campaign, use_comp_rep, explainer_cls): """Test the explain functionalities with measurements.""" """Test the default explainer in 
experimental and computational representations.""" run_iterations(campaign, n_iterations=2, batch_size=1) - for computational_representation in [False, True]: - shap_val = diag.explanation( - campaign, - computational_representation=computational_representation, - ) - assert isinstance(shap_val, shap.Explanation) + shap_val = diag.explanation( + campaign, + computational_representation=use_comp_rep, + explainer_class=explainer_cls, + ) + assert isinstance(shap_val, shap.Explanation) """Ensure that an error is raised if the data to be explained has a different number of parameters.""" @@ -77,37 +94,38 @@ def test_shapley_with_measurements(campaign): diag.explanation(campaign, data=df) +@pytest.mark.parametrize("explainer_cls", non_shap_explainers) @pytest.mark.parametrize( "parameter_names", [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], + ids=["params1"], ) -def test_non_shapley_explainers(campaign): +def test_non_shapley_explainers(campaign, explainer_cls): """Test the explain functionalities with the non-SHAP explainer MAPLE.""" run_iterations(campaign, n_iterations=2, batch_size=1) - for explainer_cls in valid_explainers: - try: - """Ensure that an error is raised if non-computational representation - is used with a non-Kernel SHAP explainer.""" - with pytest.raises( - ValueError, - match=( - "Experimental representation is not supported " - "for non-Kernel SHAP explainer." - ), - ): - diag.explanation( - campaign, - computational_representation=False, - explainer_class=explainer_cls, - ) - - """Test the non-SHAP explainer in computational representation.""" - other_explainer = diag.explanation( + try: + """Ensure that an error is raised if non-computational representation + is used with a non-Kernel SHAP explainer.""" + with pytest.raises( + ValueError, + match=( + "Experimental representation is not supported " + "for non-Kernel SHAP explainer." 
+ ), + ): + diag.explanation( campaign, - computational_representation=True, + computational_representation=False, explainer_class=explainer_cls, ) - assert isinstance(other_explainer, shap.Explanation) - except NotImplementedError: - pass + + """Test the non-SHAP explainer in computational representation.""" + other_explainer = diag.explanation( + campaign, + computational_representation=True, + explainer_class=explainer_cls, + ) + assert isinstance(other_explainer, shap.Explanation) + except NotImplementedError: + pass From 34e44bcacf502e7571dcaecedde1d6f83cd8f9be Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Mon, 18 Nov 2024 17:07:44 +0100 Subject: [PATCH 17/92] Generalized explanation for all shap explainer types, refactoring of optional dependencies, diagnostics subpackage. --- .lockfiles/py310-dev.lock | 19 ++- CHANGELOG.md | 3 +- README.md | 2 +- baybe/_optional/info.py | 2 +- baybe/diagnostics/__init__.py | 12 ++ .../diagnostics.py => diagnostics/shap.py} | 123 ++++++++++----- mypy.ini | 3 + pyproject.toml | 1 + pytest.ini | 3 +- tests/diagnostics/test_shap.py | 144 ++++++++++++++++++ tests/utils/test_diagnostics.py | 113 -------------- 11 files changed, 264 insertions(+), 161 deletions(-) create mode 100644 baybe/diagnostics/__init__.py rename baybe/{utils/diagnostics.py => diagnostics/shap.py} (59%) create mode 100644 tests/diagnostics/test_shap.py delete mode 100644 tests/utils/test_diagnostics.py diff --git a/.lockfiles/py310-dev.lock b/.lockfiles/py310-dev.lock index 8aef1fe6d..d34229b57 100644 --- a/.lockfiles/py310-dev.lock +++ b/.lockfiles/py310-dev.lock @@ -78,7 +78,9 @@ click==8.1.7 # pydoclint # streamlit cloudpickle==3.0.0 - # via dask + # via + # dask + # shap colorama==0.4.6 # via # click @@ -329,6 +331,8 @@ linear-operator==0.5.2 # via # botorch # gpytorch +llvmlite==0.43.0 + # via numba locket==1.0.0 # via partd markdown-it-py==3.0.0 @@ -408,6 +412,8 @@ notebook-shim==0.2.4 # via # jupyterlab # notebook +numba==0.60.0 + # via 
shap numpy==1.26.4 # via # baybe (pyproject.toml) @@ -421,6 +427,7 @@ numpy==1.26.4 # matplotlib # mordredcommunity # ngboost + # numba # onnx # onnxconverter-common # onnxruntime @@ -435,6 +442,7 @@ numpy==1.26.4 # scikit-learn-extra # scipy # seaborn + # shap # streamlit # types-seaborn # xarray @@ -544,6 +552,7 @@ packaging==24.1 # qtconsole # qtpy # setuptools-scm + # shap # sphinx # streamlit # tox @@ -557,6 +566,7 @@ pandas==2.2.2 # hypothesis # lifelines # seaborn + # shap # streamlit # xarray # xyzpy @@ -752,6 +762,7 @@ scikit-learn==1.5.1 # gpytorch # ngboost # scikit-learn-extra + # shap # skl2onnx scikit-learn-extra==0.3.0 # via baybe (pyproject.toml) @@ -767,6 +778,7 @@ scipy==1.14.0 # ngboost # scikit-learn # scikit-learn-extra + # shap seaborn==0.13.2 # via baybe (pyproject.toml) send2trash==1.8.3 @@ -777,6 +789,8 @@ setuptools==71.1.0 # setuptools-scm setuptools-scm==8.1.0 # via baybe (pyproject.toml) +shap==0.46.0 + # via baybe (pyproject.toml) six==1.16.0 # via # asttokens @@ -788,6 +802,8 @@ six==1.16.0 # rfc3339-validator skl2onnx==1.17.0 # via baybe (pyproject.toml) +slicer==0.0.8 + # via shap smmap==5.0.1 # via gitdb sniffio==1.3.1 @@ -901,6 +917,7 @@ tqdm==4.66.4 # via # ngboost # pyro-ppl + # shap # xyzpy traitlets==5.14.3 # via diff --git a/CHANGELOG.md b/CHANGELOG.md index 6818339da..186b80d49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added -- Added SHAP analysis within the new `diagnostics` package. +- `diagnostics` dependency group +- SHAP explanations - `allow_missing` and `allow_extra` keyword arguments to `Objective.transform` ### Deprecations diff --git a/README.md b/README.md index e9060112c..97345730a 100644 --- a/README.md +++ b/README.md @@ -296,7 +296,7 @@ The available groups are: - `mypy`: Required for static type checking. - `onnx`: Required for using custom surrogate models in [ONNX format](https://onnx.ai). 
- `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/) -- `diagnostics`: Required for feature importance ranking via [SHAP](https://shap.readthedocs.io/) +- `diagnostics`: Required for built-in model and campaign analysis, e.g. [SHAP](https://shap.readthedocs.io/)pip install uv - `simulation`: Enabling the [simulation](https://emdgroup.github.io/baybe/stable/_autosummary/baybe.simulation.html) module. - `test`: Required for running the tests. - `dev`: All of the above plus `tox` and `pip-audit`. For code contributors. diff --git a/baybe/_optional/info.py b/baybe/_optional/info.py index 33b189130..8dc838df4 100644 --- a/baybe/_optional/info.py +++ b/baybe/_optional/info.py @@ -28,7 +28,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404 MORDRED_INSTALLED = find_spec("mordred") is not None ONNX_INSTALLED = find_spec("onnxruntime") is not None POLARS_INSTALLED = find_spec("polars") is not None - DIAGNOSTICS_INSTALLED = find_spec("diagnostics") is not None + DIAGNOSTICS_INSTALLED = find_spec("shap") is not None PRE_COMMIT_INSTALLED = find_spec("pre_commit") is not None PYDOCLINT_INSTALLED = find_spec("pydoclint") is not None RDKIT_INSTALLED = find_spec("rdkit") is not None diff --git a/baybe/diagnostics/__init__.py b/baybe/diagnostics/__init__.py new file mode 100644 index 000000000..9e79a1a5c --- /dev/null +++ b/baybe/diagnostics/__init__.py @@ -0,0 +1,12 @@ +"""Baybe diagnostics (optional).""" + +from baybe._optional.info import DIAGNOSTICS_INSTALLED + +if DIAGNOSTICS_INSTALLED: + from baybe.diagnostics.shap import explainer, explanation, plot_shap_scatter + +__all__ = [ + "explainer", + "explanation", + "plot_shap_scatter", +] diff --git a/baybe/utils/diagnostics.py b/baybe/diagnostics/shap.py similarity index 59% rename from baybe/utils/diagnostics.py rename to baybe/diagnostics/shap.py index a1e1a4e62..97f375237 100644 --- a/baybe/utils/diagnostics.py +++ b/baybe/diagnostics/shap.py @@ -1,4 +1,4 @@ 
-"""Diagnostics utilities.""" +"""SHAP utilities.""" import numbers import warnings @@ -32,6 +32,9 @@ def explainer( Raises: ValueError: If no measurements have been provided yet. + NotImplementedError: If the provided explainer does not support + the campaign surrogate. + TypeError: If the provided explainer does not support the campaign surrogate. """ if campaign.measurements.empty: raise ValueError("No measurements have been provided yet.") @@ -54,13 +57,27 @@ def model(x): return output.detach().numpy() - shap_explainer = explainer_class(model, data, **kwargs) + if ( + campaign.searchspace.type != "CONTINUOUS" + and not computational_representation + and not explainer_class == shap.KernelExplainer + ): + raise NotImplementedError( + "Only KernelExplainer is supported for non-continous searchspaces." + ) + + try: + shap_explainer = explainer_class(model, data, **kwargs) + except shap.utils._exceptions.InvalidModelError: + raise TypeError( + "The selected explainer class does not support the campaign surrogate." + ) return shap_explainer def explanation( campaign: Campaign, - data: np.ndarray = None, + data: pd.DataFrame | None = None, explainer_class: shap.Explainer = shap.KernelExplainer, computational_representation: bool = False, **kwargs, @@ -82,8 +99,12 @@ def explanation( The Shapley values for the provided campaign. Raises: + ValueError: If the provided explainer does not support experimental + representation. + NotImplementedError: If the provided explainer does not support + the campaign surrogate. ValueError: If the provided data does not have the same amount of parameters - as previously provided to the explainer. + as the campaign. """ is_shap_explainer = not explainer_class.__module__.startswith( "shap.explainers.other." @@ -95,67 +116,82 @@ def explanation( "supported for non-Kernel SHAP explainer." 
) - explainer_obj = explainer( - campaign, - explainer_class=explainer_class, - computational_representation=computational_representation, - **kwargs, - ) + try: + explainer_obj = explainer( + campaign, + explainer_class=explainer_class, + computational_representation=computational_representation, + **kwargs, + ) + except NotImplementedError: + warnings.warn( + "The provided Explainer does not support experimental representation. " + "Switching to computational representation. " + "Otherwise consider using a different explainer (e.g. KernelExplainer)." + ) + return explanation( + campaign, + data=data, + explainer_class=explainer_class, + computational_representation=True, + **kwargs, + ) if data is None: - if isinstance(explainer_obj.data, np.ndarray): - data = explainer_obj.data - else: - data = explainer_obj.data.data - elif computational_representation: - data = campaign.searchspace.transform(data) + data = campaign.measurements[[p.name for p in campaign.parameters]].copy() + elif set(campaign.searchspace.parameter_names) != set(data.columns.values): + raise ValueError( + "The provided data does not have the same amount of parameters " + "as specified for the campaign." + ) + if computational_representation: + data = campaign.searchspace.transform(pd.DataFrame(data)) + + """Get background data depending on the explainer.""" + bg_data = getattr(explainer_obj, "data", getattr(explainer_obj, "masker", None)) + bg_data = getattr(bg_data, "data", bg_data) + + # Type checking for mypy + bg_data = bg_data if isinstance(bg_data, pd.DataFrame) else pd.DataFrame(bg_data) + assert isinstance(data, pd.DataFrame) + + if not bg_data.shape[1] == data.shape[1]: + raise ValueError( + "The provided data does not have the same amount of " + "parameters as the shap explainer background." 
+ ) if not is_shap_explainer: """Return attributions for non-SHAP explainers.""" if explainer_class.__module__.endswith("maple"): """Additional argument for maple to increase comparability to SHAP.""" - attributions = explainer_obj.attributions(data, multiply_by_input=True)[0] + attributions = explainer_obj.attributions( + np.array(data), multiply_by_input=True + )[0] else: - attributions = explainer_obj.attributions(data)[0] + attributions = explainer_obj.attributions(np.array(data))[0] if computational_representation: feature_names = campaign.searchspace.comp_rep_columns else: feature_names = campaign.searchspace.parameter_names explanations = shap.Explanation( values=attributions, - base_values=data, - data=data, + base_values=np.array(data), + data=np.array(data), ) explanations.feature_names = list(feature_names) return explanations - if data.shape[1] != explainer_obj.data.data.shape[1]: - raise ValueError( - "The provided data does not have the same amount " - "of parameters as the shap explainer background." - ) + shap_explanations = explainer_obj(np.array(data)) + if len(shap_explanations.shape) == 2: + return shap_explanations else: shap_explanations = explainer_obj(data)[:, :, 0] return shap_explanations -def plot_beeswarm(explanation: shap.Explanation, **kwargs) -> None: - """Plot the Shapley values using a beeswarm plot.""" - shap.plots.beeswarm(explanation, **kwargs) - - -def plot_waterfall(explanation: shap.Explanation, **kwargs) -> None: - """Plot the Shapley values using a waterfall plot.""" - shap.plots.waterfall(explanation, **kwargs) - - -def plot_bar(explanation: shap.Explanation, **kwargs) -> None: - """Plot the Shapley values using a bar plot.""" - shap.plots.bar(explanation, **kwargs) - - -def plot_scatter(explanation: shap.Explanation | memoryview, **kwargs) -> None: +def plot_shap_scatter(explanation: shap.Explanation, **kwargs) -> None: """Plot the Shapley values using a scatter plot while leaving out string values. 
Args: @@ -176,7 +212,7 @@ def plot_scatter(explanation: shap.Explanation | memoryview, **kwargs) -> None: def is_not_numeric_column(col): return np.array([not isinstance(v, numbers.Number) for v in col]).any() - if data.ndim == 1: + if np.ndim(data) == 1: if is_not_numeric_column(data): warnings.warn( "Cannot plot scatter plot for the provided " @@ -185,6 +221,9 @@ def is_not_numeric_column(col): else: shap.plots.scatter(explanation, **kwargs) else: + # Type checking for mypy + assert isinstance(data, np.ndarray) + number_enum = [i for i, x in enumerate(data[1]) if not isinstance(x, str)] if len(number_enum) < len(explanation.feature_names): warnings.warn( diff --git a/mypy.ini b/mypy.ini index 1188e9e6e..e697ec5c5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -65,5 +65,8 @@ ignore_missing_imports = True [mypy-polars] ignore_missing_imports = True +[mypy-shap] +ignore_missing_imports = True + [mypy-diagnostics] ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index adfc790bf..9a6beb607 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,7 @@ dev = [ "baybe[polars]", "baybe[simulation]", "baybe[test]", + "baybe[diagnostics]", "pip-audit>=2.5.5", "tox-uv>=1.7.0", "uv>=0.3.0", # `uv lock` (for lockfiles) is stable since 0.3.0: https://github.com/astral-sh/uv/issues/2679#event-13950215962 diff --git a/pytest.ini b/pytest.ini index 283221189..36bc17480 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,8 +10,7 @@ addopts = --ignore=baybe/_optional --ignore=baybe/utils/chemistry.py --ignore=tests/simulate_telemetry.py - --ignore=baybe/utils/diagnostics.py - --ignore=tests/utils/test_diagnostics.py + --ignore=baybe/diagnostics testpaths = baybe tests \ No newline at end of file diff --git a/tests/diagnostics/test_shap.py b/tests/diagnostics/test_shap.py new file mode 100644 index 000000000..25abac227 --- /dev/null +++ b/tests/diagnostics/test_shap.py @@ -0,0 +1,144 @@ +"""Tests for diagnostic utilities.""" + +import inspect + +import pandas as 
pd +import pytest + +from baybe._optional.info import DIAGNOSTICS_INSTALLED +from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender +from baybe.recommenders.pure.bayesian.base import BayesianRecommender +from baybe.searchspace import SearchSpaceType +from baybe.utils.basic import get_subclasses +from tests.conftest import run_iterations + +pytestmark = pytest.mark.skipif( + not DIAGNOSTICS_INSTALLED, reason="Optional diagnostics dependency not installed." +) +if DIAGNOSTICS_INSTALLED: + import shap + + from baybe import diagnostics as diag + + def _has_required_init_parameters(cls): + """Helper function checks if initializer has required standard parameters.""" + required_parameters = ["model", "data"] + init_signature = inspect.signature(cls.__init__) + parameters = list(init_signature.parameters.keys()) + return parameters[:3] == required_parameters + + valid_non_shap_explainers = [ + getattr(shap.explainers.other, cls_name) + for cls_name in shap.explainers.other.__all__ + if _has_required_init_parameters(getattr(shap.explainers.other, cls_name)) + ] + + shap_explainers = [ + getattr(shap.explainers, cls_name) for cls_name in shap.explainers.__all__ + ] + +valid_hybrid_bayesian_recommenders = [ + TwoPhaseMetaRecommender(recommender=cls()) + for cls in get_subclasses(BayesianRecommender) + if cls.compatibility == SearchSpaceType.HYBRID +] + + +def _run_explainer_tests(campaign, explainers, representation_types): + """Helper to test explainers for different representation types.""" + for explainer_cls in explainers: + for representation in representation_types: + shap_val = diag.explanation( + campaign, + computational_representation=representation, + explainer_class=explainer_cls, + ) + assert isinstance(shap_val, shap.Explanation) + + +def _test_shap_explainers(campaign, explainers, check_param_count=False): + run_iterations(campaign, n_iterations=2, batch_size=1) + try: + _run_explainer_tests(campaign, explainers, [False, True]) + if 
check_param_count: + df = pd.DataFrame({"Num_disc_1": [0, 2]}) + with pytest.raises( + ValueError, + match="The provided data does not have the same " + "amount of parameters as specified for the campaign.", + ): + diag.explanation(campaign, data=df, explainer_class=shap_explainers[0]) + except ModuleNotFoundError as e: + if "No module named 'tensorflow'" in str(e): + pass + else: + raise e + except TypeError as e: + if ( + "The selected explainer class does not support the campaign surrogate." + in str(e) + ): + pass + else: + raise e + + +@pytest.mark.parametrize( + "parameter_names", + [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], +) +def test_shapley_values_no_measurements(campaign): + """A campaign without measurements raises an error.""" + with pytest.raises(ValueError, match="No measurements have been provided yet."): + diag.explanation(campaign) + + +@pytest.mark.slow +@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize("parameter_names", [["Conti_finite1", "Conti_finite2"]]) +def test_shapley_with_measurements_continuous(campaign): + """Test the explain functionalities with measurements.""" + run_iterations(campaign, n_iterations=2, batch_size=1) + for explainer_cls in shap_explainers: + _test_shap_explainers(campaign, shap_explainers) + + +@pytest.mark.slow +@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize( + "parameter_names", + [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], +) +def test_shapley_with_measurements(campaign): + """Test the explain functionalities with measurements.""" + run_iterations(campaign, n_iterations=2, batch_size=1) + for explainer_cls in shap_explainers: + _test_shap_explainers(campaign, shap_explainers) + + +@pytest.mark.parametrize( + "parameter_names", + [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], +) +def test_non_shapley_explainers(campaign): + """Test the 
explain functionalities with the non-SHAP explainer MAPLE.""" + run_iterations(campaign, n_iterations=2, batch_size=1) + + for explainer_cls in valid_non_shap_explainers: + """Ensure that an error is raised if non-computational representation + is used with a non-Kernel SHAP explainer.""" + with pytest.raises( + ValueError, + match=( + "Experimental representation is not supported " + "for non-Kernel SHAP explainer." + ), + ): + diag.explanation( + campaign, + computational_representation=False, + explainer_class=explainer_cls, + ) + + """Test the non-SHAP explainer in computational representation.""" + _run_explainer_tests(campaign, [explainer_cls], [True]) diff --git a/tests/utils/test_diagnostics.py b/tests/utils/test_diagnostics.py deleted file mode 100644 index eed039371..000000000 --- a/tests/utils/test_diagnostics.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Tests for diagnostic utilities.""" - -import inspect - -import pandas as pd -import pytest - -import baybe.utils.diagnostics as diag -from baybe._optional.diagnostics import shap -from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender -from baybe.recommenders.pure.bayesian.base import BayesianRecommender -from baybe.searchspace import SearchSpaceType -from baybe.utils.basic import get_subclasses -from tests.conftest import run_iterations - - -def has_required_init_parameters(cls): - """Helpfer function checks if initializer has required standard parameters.""" - required_parameters = ["self", "model", "data"] - init_signature = inspect.signature(cls.__init__) - parameters = list(init_signature.parameters.keys()) - return parameters[:3] == required_parameters - - -valid_explainers = [ - getattr(shap.explainers.other, cls_name) - for cls_name in shap.explainers.other.__all__ - if has_required_init_parameters(getattr(shap.explainers.other, cls_name)) -] - -valid_hybrid_bayesian_recommenders = [ - TwoPhaseMetaRecommender(recommender=cls()) - for cls in get_subclasses(BayesianRecommender) - if 
cls.compatibility == SearchSpaceType.HYBRID -] - - -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], -) -def test_shapley_values_no_measurements(campaign): - """A campaign without measurements raises an error.""" - with pytest.raises(ValueError, match="No measurements have been provided yet."): - diag.explanation(campaign) - - -@pytest.mark.slow -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], -) -def test_shapley_with_measurements(campaign): - """Test the explain functionalities with measurements.""" - """Test the default explainer in experimental - and computational representations.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - - for computational_representation in [False, True]: - shap_val = diag.explanation( - campaign, - computational_representation=computational_representation, - ) - assert isinstance(shap_val, shap.Explanation) - - """Ensure that an error is raised if the data - to be explained has a different number of parameters.""" - df = pd.DataFrame({"Num_disc_1": [0, 2]}) - with pytest.raises( - ValueError, - match=( - "The provided data does not have the same " - "amount of parameters as the shap explainer background." 
- ), - ): - diag.explanation(campaign, data=df) - - -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], -) -def test_non_shapley_explainers(campaign): - """Test the explain functionalities with the non-SHAP explainer MAPLE.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - - for explainer_cls in valid_explainers: - try: - """Ensure that an error is raised if non-computational representation - is used with a non-Kernel SHAP explainer.""" - with pytest.raises( - ValueError, - match=( - "Experimental representation is not supported " - "for non-Kernel SHAP explainer." - ), - ): - diag.explanation( - campaign, - computational_representation=False, - explainer_class=explainer_cls, - ) - - """Test the non-SHAP explainer in computational representation.""" - other_explainer = diag.explanation( - campaign, - computational_representation=True, - explainer_class=explainer_cls, - ) - assert isinstance(other_explainer, shap.Explanation) - except NotImplementedError: - pass From 7823e87f9158dae1e9edbd44bd551d16e3089ff8 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Wed, 20 Nov 2024 00:58:39 +0100 Subject: [PATCH 18/92] Reworked tests from feedback, Cleanup for Review --- baybe/diagnostics/shap.py | 37 +++++---- tests/diagnostics/test_shap.py | 120 +++++++++++++++-------------- tests/utils/test_diagnostics.py | 131 -------------------------------- 3 files changed, 83 insertions(+), 205 deletions(-) delete mode 100644 tests/utils/test_diagnostics.py diff --git a/baybe/diagnostics/shap.py b/baybe/diagnostics/shap.py index 97f375237..9ed170063 100644 --- a/baybe/diagnostics/shap.py +++ b/baybe/diagnostics/shap.py @@ -105,6 +105,7 @@ def explanation( the campaign surrogate. ValueError: If the provided data does not have the same amount of parameters as the campaign. + ValueError: If the resulting explanation object has an invalid shape. 
""" is_shap_explainer = not explainer_class.__module__.startswith( "shap.explainers.other." @@ -147,9 +148,10 @@ def explanation( if computational_representation: data = campaign.searchspace.transform(pd.DataFrame(data)) - """Get background data depending on the explainer.""" + """Get background data regardless of the explainer.""" bg_data = getattr(explainer_obj, "data", getattr(explainer_obj, "masker", None)) - bg_data = getattr(bg_data, "data", bg_data) + if not isinstance(bg_data, (np.ndarray, pd.DataFrame)): + bg_data = getattr(bg_data, "data") # Type checking for mypy bg_data = bg_data if isinstance(bg_data, pd.DataFrame) else pd.DataFrame(bg_data) @@ -161,6 +163,7 @@ def explanation( "parameters as the shap explainer background." ) + data_array = np.array(data) if not is_shap_explainer: """Return attributions for non-SHAP explainers.""" if explainer_class.__module__.endswith("maple"): @@ -170,28 +173,32 @@ def explanation( )[0] else: attributions = explainer_obj.attributions(np.array(data))[0] - if computational_representation: - feature_names = campaign.searchspace.comp_rep_columns - else: - feature_names = campaign.searchspace.parameter_names + feature_names = ( + campaign.searchspace.comp_rep_columns + if computational_representation + else campaign.searchspace.parameter_names + ) explanations = shap.Explanation( values=attributions, - base_values=np.array(data), - data=np.array(data), + base_values=data_array, + data=data_array, ) explanations.feature_names = list(feature_names) return explanations - - shap_explanations = explainer_obj(np.array(data)) - if len(shap_explanations.shape) == 2: - return shap_explanations else: - shap_explanations = explainer_obj(data)[:, :, 0] + explanations = explainer_obj(data_array) - return shap_explanations + if len(explanations.shape) == 2: + return explanations + if len(explanations.shape) == 3: + return explanations[:, :, 0] + raise ValueError( + "The Explanation has an invalid " + f"dimensionality of 
{len(explanations.shape)}." + ) -def plot_shap_scatter(explanation: shap.Explanation, **kwargs) -> None: +def plot_shap_scatter(explanation: memoryview | shap.Explanation, **kwargs) -> None: """Plot the Shapley values using a scatter plot while leaving out string values. Args: diff --git a/tests/diagnostics/test_shap.py b/tests/diagnostics/test_shap.py index 25abac227..c80d5e034 100644 --- a/tests/diagnostics/test_shap.py +++ b/tests/diagnostics/test_shap.py @@ -4,6 +4,7 @@ import pandas as pd import pytest +from pytest import param from baybe._optional.info import DIAGNOSTICS_INSTALLED from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender @@ -20,59 +21,59 @@ from baybe import diagnostics as diag + EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] + def _has_required_init_parameters(cls): """Helper function checks if initializer has required standard parameters.""" - required_parameters = ["model", "data"] + required_parameters = ["self", "model", "data"] init_signature = inspect.signature(cls.__init__) parameters = list(init_signature.parameters.keys()) return parameters[:3] == required_parameters - valid_non_shap_explainers = [ - getattr(shap.explainers.other, cls_name) + non_shap_explainers = [ + param(explainer, id=f"{cls_name}") for cls_name in shap.explainers.other.__all__ - if _has_required_init_parameters(getattr(shap.explainers.other, cls_name)) + if _has_required_init_parameters( + explainer := getattr(shap.explainers.other, cls_name) + ) + and all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) ] shap_explainers = [ - getattr(shap.explainers, cls_name) for cls_name in shap.explainers.__all__ + param(getattr(shap.explainers, cls_name), id=f"{cls_name}") + for cls_name in shap.explainers.__all__ + if all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) ] valid_hybrid_bayesian_recommenders = [ - TwoPhaseMetaRecommender(recommender=cls()) + param(TwoPhaseMetaRecommender(recommender=cls()), 
id=f"{cls.__name__}") for cls in get_subclasses(BayesianRecommender) if cls.compatibility == SearchSpaceType.HYBRID ] -def _run_explainer_tests(campaign, explainers, representation_types): - """Helper to test explainers for different representation types.""" - for explainer_cls in explainers: - for representation in representation_types: - shap_val = diag.explanation( +def _test_explainer(campaign, explainer_cls, use_comp_rep): + """Helper function for general explainer tests.""" + run_iterations(campaign, n_iterations=2, batch_size=1) + try: + shap_val = diag.explanation( + campaign, + computational_representation=use_comp_rep, + explainer_class=explainer_cls, + ) + assert isinstance(shap_val, shap.Explanation) + df = pd.DataFrame({"Num_disc_1": [0, 2]}) + with pytest.raises( + ValueError, + match="The provided data does not have the same " + "amount of parameters as specified for the campaign.", + ): + diag.explanation( campaign, - computational_representation=representation, + data=df, + computational_representation=True, explainer_class=explainer_cls, ) - assert isinstance(shap_val, shap.Explanation) - - -def _test_shap_explainers(campaign, explainers, check_param_count=False): - run_iterations(campaign, n_iterations=2, batch_size=1) - try: - _run_explainer_tests(campaign, explainers, [False, True]) - if check_param_count: - df = pd.DataFrame({"Num_disc_1": [0, 2]}) - with pytest.raises( - ValueError, - match="The provided data does not have the same " - "amount of parameters as specified for the campaign.", - ): - diag.explanation(campaign, data=df, explainer_class=shap_explainers[0]) - except ModuleNotFoundError as e: - if "No module named 'tensorflow'" in str(e): - pass - else: - raise e except TypeError as e: if ( "The selected explainer class does not support the campaign surrogate." 
@@ -95,50 +96,51 @@ def test_shapley_values_no_measurements(campaign): @pytest.mark.slow @pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize("explainer_cls", shap_explainers) +@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) @pytest.mark.parametrize("parameter_names", [["Conti_finite1", "Conti_finite2"]]) -def test_shapley_with_measurements_continuous(campaign): +def test_shapley_with_measurements_continuous(campaign, explainer_cls, use_comp_rep): """Test the explain functionalities with measurements.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - for explainer_cls in shap_explainers: - _test_shap_explainers(campaign, shap_explainers) + _test_explainer(campaign, explainer_cls, use_comp_rep) @pytest.mark.slow @pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize("explainer_cls", shap_explainers) +@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) @pytest.mark.parametrize( "parameter_names", [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], ) -def test_shapley_with_measurements(campaign): +def test_shapley_with_measurements(campaign, explainer_cls, use_comp_rep): """Test the explain functionalities with measurements.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - for explainer_cls in shap_explainers: - _test_shap_explainers(campaign, shap_explainers) + _test_explainer(campaign, explainer_cls, use_comp_rep) +@pytest.mark.parametrize("explainer_cls", non_shap_explainers) @pytest.mark.parametrize( "parameter_names", [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], + ids=["params1"], ) -def test_non_shapley_explainers(campaign): +def test_non_shapley_explainers(campaign, explainer_cls): """Test the explain functionalities with the non-SHAP explainer MAPLE.""" run_iterations(campaign, n_iterations=2, batch_size=1) - for explainer_cls in 
valid_non_shap_explainers: - """Ensure that an error is raised if non-computational representation - is used with a non-Kernel SHAP explainer.""" - with pytest.raises( - ValueError, - match=( - "Experimental representation is not supported " - "for non-Kernel SHAP explainer." - ), - ): - diag.explanation( - campaign, - computational_representation=False, - explainer_class=explainer_cls, - ) - - """Test the non-SHAP explainer in computational representation.""" - _run_explainer_tests(campaign, [explainer_cls], [True]) + """Ensure that an error is raised if non-computational representation + is used with a non-Kernel SHAP explainer.""" + with pytest.raises( + ValueError, + match=( + "Experimental representation is not supported " + "for non-Kernel SHAP explainer." + ), + ): + diag.explanation( + campaign, + computational_representation=False, + explainer_class=explainer_cls, + ) + + """Test the non-SHAP explainer in computational representation.""" + _test_explainer(campaign, explainer_cls, use_comp_rep=True) diff --git a/tests/utils/test_diagnostics.py b/tests/utils/test_diagnostics.py deleted file mode 100644 index 4b2f016f3..000000000 --- a/tests/utils/test_diagnostics.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Tests for diagnostic utilities.""" - -import inspect - -import pandas as pd -import pytest -from pytest import param - -import baybe.utils.diagnostics as diag -from baybe._optional.diagnostics import shap -from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender -from baybe.recommenders.pure.bayesian.base import BayesianRecommender -from baybe.searchspace import SearchSpaceType -from baybe.utils.basic import get_subclasses -from tests.conftest import run_iterations - - -def _has_required_init_parameters(cls): - """Helper function checks if initializer has required standard parameters.""" - required_parameters = ["self", "model", "data"] - init_signature = inspect.signature(cls.__init__) - parameters = list(init_signature.parameters.keys()) - 
return parameters[:3] == required_parameters - - -non_shap_explainers = [ - param(explainer, id=f"{explainer.__name__}") - for cls_name in shap.explainers.other.__all__ - if _has_required_init_parameters( - explainer := getattr(shap.explainers.other, cls_name) - ) -] - - -shap_explainers = [ - param(explainer, id=f"{explainer.__name__}") - for cls_name in shap.explainers.__all__ - if _has_required_init_parameters(explainer := getattr(shap.explainers, cls_name)) - and all( - x not in explainer.__name__ - for x in ["Tree", "GPU", "Gradient", "Sampling", "Deep"] - ) -] - -valid_hybrid_bayesian_recommenders = [ - param(TwoPhaseMetaRecommender(recommender=cls()), id=f"{cls.__name__}") - for cls in get_subclasses(BayesianRecommender) - if cls.compatibility == SearchSpaceType.HYBRID -] - - -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], -) -def test_shapley_values_no_measurements(campaign): - """A campaign without measurements raises an error.""" - with pytest.raises(ValueError, match="No measurements have been provided yet."): - diag.explanation(campaign) - - -@pytest.mark.slow -@pytest.mark.parametrize("explainer_cls", shap_explainers) -@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], - ids=["params1"], -) -def test_shapley_with_measurements(campaign, use_comp_rep, explainer_cls): - """Test the explain functionalities with measurements.""" - """Test the default explainer in experimental - and computational representations.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - - shap_val = diag.explanation( - campaign, - computational_representation=use_comp_rep, - explainer_class=explainer_cls, - ) - assert isinstance(shap_val, shap.Explanation) - - """Ensure that an error is 
raised if the data - to be explained has a different number of parameters.""" - df = pd.DataFrame({"Num_disc_1": [0, 2]}) - with pytest.raises( - ValueError, - match=( - "The provided data does not have the same " - "amount of parameters as the shap explainer background." - ), - ): - diag.explanation(campaign, data=df) - - -@pytest.mark.parametrize("explainer_cls", non_shap_explainers) -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], - ids=["params1"], -) -def test_non_shapley_explainers(campaign, explainer_cls): - """Test the explain functionalities with the non-SHAP explainer MAPLE.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - - try: - """Ensure that an error is raised if non-computational representation - is used with a non-Kernel SHAP explainer.""" - with pytest.raises( - ValueError, - match=( - "Experimental representation is not supported " - "for non-Kernel SHAP explainer." - ), - ): - diag.explanation( - campaign, - computational_representation=False, - explainer_class=explainer_cls, - ) - - """Test the non-SHAP explainer in computational representation.""" - other_explainer = diag.explanation( - campaign, - computational_representation=True, - explainer_class=explainer_cls, - ) - assert isinstance(other_explainer, shap.Explanation) - except NotImplementedError: - pass From 0c2deb6dd144b396e2beb54d644cd6b6b11841f7 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Wed, 20 Nov 2024 10:56:01 +0100 Subject: [PATCH 19/92] Further cleanup --- baybe/diagnostics/shap.py | 37 ++++++++++++++++++++-------------- tests/diagnostics/test_shap.py | 36 ++++++++++++++++++++------------- 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/baybe/diagnostics/shap.py b/baybe/diagnostics/shap.py index 9ed170063..46da569de 100644 --- a/baybe/diagnostics/shap.py +++ b/baybe/diagnostics/shap.py @@ -11,6 +11,22 @@ from baybe.utils.dataframe import to_tensor +def 
_get_explainer_feature_len(explainer: shap.Explainer) -> int: + """Get dimensionality of explainer background data. + + Args: + explainer: The explainer object as implemented in the SHAP package. + + Returns: + The dimensionality of the background data. + """ + bg_data = getattr(explainer, "data", getattr(explainer, "masker", None)) + if not isinstance(bg_data, pd.DataFrame): + bg_data = getattr(bg_data, "data") + + return bg_data.shape[1] + + def explainer( campaign: Campaign, explainer_class: shap.Explainer = shap.KernelExplainer, @@ -148,31 +164,22 @@ def explanation( if computational_representation: data = campaign.searchspace.transform(pd.DataFrame(data)) - """Get background data regardless of the explainer.""" - bg_data = getattr(explainer_obj, "data", getattr(explainer_obj, "masker", None)) - if not isinstance(bg_data, (np.ndarray, pd.DataFrame)): - bg_data = getattr(bg_data, "data") - # Type checking for mypy - bg_data = bg_data if isinstance(bg_data, pd.DataFrame) else pd.DataFrame(bg_data) assert isinstance(data, pd.DataFrame) - if not bg_data.shape[1] == data.shape[1]: + if not _get_explainer_feature_len(explainer_obj) == data.shape[1]: raise ValueError( "The provided data does not have the same amount of " "parameters as the shap explainer background." 
) - data_array = np.array(data) if not is_shap_explainer: """Return attributions for non-SHAP explainers.""" if explainer_class.__module__.endswith("maple"): """Additional argument for maple to increase comparability to SHAP.""" - attributions = explainer_obj.attributions( - np.array(data), multiply_by_input=True - )[0] + attributions = explainer_obj.attributions(data, multiply_by_input=True)[0] else: - attributions = explainer_obj.attributions(np.array(data))[0] + attributions = explainer_obj.attributions(data)[0] feature_names = ( campaign.searchspace.comp_rep_columns if computational_representation @@ -180,13 +187,13 @@ def explanation( ) explanations = shap.Explanation( values=attributions, - base_values=data_array, - data=data_array, + base_values=data, + data=data, ) explanations.feature_names = list(feature_names) return explanations else: - explanations = explainer_obj(data_array) + explanations = explainer_obj(data) if len(explanations.shape) == 2: return explanations diff --git a/tests/diagnostics/test_shap.py b/tests/diagnostics/test_shap.py index c80d5e034..b07e75ad2 100644 --- a/tests/diagnostics/test_shap.py +++ b/tests/diagnostics/test_shap.py @@ -16,21 +16,25 @@ pytestmark = pytest.mark.skipif( not DIAGNOSTICS_INSTALLED, reason="Optional diagnostics dependency not installed." 
) + if DIAGNOSTICS_INSTALLED: import shap from baybe import diagnostics as diag - EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] +EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] + + +def _has_required_init_parameters(cls): + """Helper function checks if initializer has required standard parameters.""" + REQUIRED_PARAMETERS = ["self", "model", "data"] + init_signature = inspect.signature(cls.__init__) + parameters = list(init_signature.parameters.keys()) + return parameters[:3] == REQUIRED_PARAMETERS - def _has_required_init_parameters(cls): - """Helper function checks if initializer has required standard parameters.""" - required_parameters = ["self", "model", "data"] - init_signature = inspect.signature(cls.__init__) - parameters = list(init_signature.parameters.keys()) - return parameters[:3] == required_parameters - non_shap_explainers = [ +non_shap_explainers = ( + [ param(explainer, id=f"{cls_name}") for cls_name in shap.explainers.other.__all__ if _has_required_init_parameters( @@ -38,12 +42,19 @@ def _has_required_init_parameters(cls): ) and all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) ] + if DIAGNOSTICS_INSTALLED + else [] +) - shap_explainers = [ +shap_explainers = ( + [ param(getattr(shap.explainers, cls_name), id=f"{cls_name}") for cls_name in shap.explainers.__all__ if all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) ] + if DIAGNOSTICS_INSTALLED + else [] +) valid_hybrid_bayesian_recommenders = [ param(TwoPhaseMetaRecommender(recommender=cls()), id=f"{cls.__name__}") @@ -125,8 +136,8 @@ def test_shapley_with_measurements(campaign, explainer_cls, use_comp_rep): ) def test_non_shapley_explainers(campaign, explainer_cls): """Test the explain functionalities with the non-SHAP explainer MAPLE.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - + """Test the non-SHAP explainer in computational representation.""" + _test_explainer(campaign, explainer_cls, 
use_comp_rep=True) """Ensure that an error is raised if non-computational representation is used with a non-Kernel SHAP explainer.""" with pytest.raises( @@ -141,6 +152,3 @@ def test_non_shapley_explainers(campaign, explainer_cls): computational_representation=False, explainer_class=explainer_cls, ) - - """Test the non-SHAP explainer in computational representation.""" - _test_explainer(campaign, explainer_cls, use_comp_rep=True) From 436c08fc142632e8234538f1406db819e1d49f3c Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Mon, 9 Dec 2024 23:35:09 +0100 Subject: [PATCH 20/92] Renaming of "diagnostics" package into "insights", Addition of Insight and SHAPInsight classes, refactoring of SHAP functions into SHAPInsight class. --- CHANGELOG.md | 2 +- CONTRIBUTORS.md | 2 +- README.md | 2 +- baybe/_optional/diagnostics.py | 16 -- baybe/_optional/info.py | 2 +- baybe/_optional/insights.py | 16 ++ baybe/diagnostics/__init__.py | 12 -- baybe/diagnostics/shap.py | 247 ------------------------- baybe/insights/__init__.py | 12 ++ baybe/insights/base.py | 45 +++++ baybe/insights/shap.py | 319 +++++++++++++++++++++++++++++++++ pyproject.toml | 6 +- pytest.ini | 2 +- tests/diagnostics/test_shap.py | 154 ---------------- tests/insights/test_shap.py | 198 ++++++++++++++++++++ tox.ini | 2 +- 16 files changed, 599 insertions(+), 438 deletions(-) delete mode 100644 baybe/_optional/diagnostics.py create mode 100644 baybe/_optional/insights.py delete mode 100644 baybe/diagnostics/__init__.py delete mode 100644 baybe/diagnostics/shap.py create mode 100644 baybe/insights/__init__.py create mode 100644 baybe/insights/base.py create mode 100644 baybe/insights/shap.py delete mode 100644 tests/diagnostics/test_shap.py create mode 100644 tests/insights/test_shap.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 186b80d49..58138310d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## 
[Unreleased] ### Added -- `diagnostics` dependency group +- `insights` dependency group - SHAP explanations - `allow_missing` and `allow_extra` keyword arguments to `Objective.transform` diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 140817a83..530f51d9d 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -27,4 +27,4 @@ - Julian Streibel (Merck Life Science KGaA, Darmstadt, Germany):\ Bernoulli multi-armed bandit and Thompson sampling - Alexander Wieczorek (Swiss Federal Institute for Materials Science and Technology, Dübendorf, Switzerland):\ - SHAP explainers for diagnoatics \ No newline at end of file + SHAP explainers for insights \ No newline at end of file diff --git a/README.md b/README.md index 97345730a..cebeb1b13 100644 --- a/README.md +++ b/README.md @@ -296,7 +296,7 @@ The available groups are: - `mypy`: Required for static type checking. - `onnx`: Required for using custom surrogate models in [ONNX format](https://onnx.ai). - `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/) -- `diagnostics`: Required for built-in model and campaign analysis, e.g. [SHAP](https://shap.readthedocs.io/)pip install uv +- `insights`: Required for built-in model and campaign analysis, e.g. [SHAP](https://shap.readthedocs.io/) - `simulation`: Enabling the [simulation](https://emdgroup.github.io/baybe/stable/_autosummary/baybe.simulation.html) module. - `test`: Required for running the tests. - `dev`: All of the above plus `tox` and `pip-audit`. For code contributors. 
diff --git a/baybe/_optional/diagnostics.py b/baybe/_optional/diagnostics.py deleted file mode 100644 index 08138f657..000000000 --- a/baybe/_optional/diagnostics.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Optional import for diagnostics utilities.""" - -from baybe.exceptions import OptionalImportError - -try: - import shap -except ModuleNotFoundError as ex: - raise OptionalImportError( - "Explainer functionality is unavailable because 'diagnostics' is not installed." - " Consider installing BayBE with 'diagnostics' dependency, e.g. via " - "`pip install baybe[diagnostics]`." - ) from ex - -__all__ = [ - "shap", -] diff --git a/baybe/_optional/info.py b/baybe/_optional/info.py index db7518aca..6cbe0abc1 100644 --- a/baybe/_optional/info.py +++ b/baybe/_optional/info.py @@ -46,7 +46,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404 # Package combinations CHEM_INSTALLED = MORDRED_INSTALLED and RDKIT_INSTALLED -DIAGNOSTICS_INSTALLED = SHAP_INSTALLED +INSIGHTS_INSTALLED = SHAP_INSTALLED LINT_INSTALLED = all( ( FLAKE8_INSTALLED, diff --git a/baybe/_optional/insights.py b/baybe/_optional/insights.py new file mode 100644 index 000000000..0f83786bf --- /dev/null +++ b/baybe/_optional/insights.py @@ -0,0 +1,16 @@ +"""Optional import for insight subpackage.""" + +from baybe.exceptions import OptionalImportError + +try: + import shap +except ModuleNotFoundError as ex: + raise OptionalImportError( + "Explainer functionality is unavailable because 'insights' is not installed." + " Consider installing BayBE with 'insights' dependency, e.g. via " + "`pip install baybe[insights]`." 
+ ) from ex + +__all__ = [ + "shap", +] diff --git a/baybe/diagnostics/__init__.py b/baybe/diagnostics/__init__.py deleted file mode 100644 index 9e79a1a5c..000000000 --- a/baybe/diagnostics/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Baybe diagnostics (optional).""" - -from baybe._optional.info import DIAGNOSTICS_INSTALLED - -if DIAGNOSTICS_INSTALLED: - from baybe.diagnostics.shap import explainer, explanation, plot_shap_scatter - -__all__ = [ - "explainer", - "explanation", - "plot_shap_scatter", -] diff --git a/baybe/diagnostics/shap.py b/baybe/diagnostics/shap.py deleted file mode 100644 index 46da569de..000000000 --- a/baybe/diagnostics/shap.py +++ /dev/null @@ -1,247 +0,0 @@ -"""SHAP utilities.""" - -import numbers -import warnings - -import numpy as np -import pandas as pd - -from baybe import Campaign -from baybe._optional.diagnostics import shap -from baybe.utils.dataframe import to_tensor - - -def _get_explainer_feature_len(explainer: shap.Explainer) -> int: - """Get dimensionality of explainer background data. - - Args: - explainer: The explainer object as implemented in the SHAP package. - - Returns: - The dimensionality of the background data. - """ - bg_data = getattr(explainer, "data", getattr(explainer, "masker", None)) - if not isinstance(bg_data, pd.DataFrame): - bg_data = getattr(bg_data, "data") - - return bg_data.shape[1] - - -def explainer( - campaign: Campaign, - explainer_class: shap.Explainer = shap.KernelExplainer, - computational_representation: bool = False, - **kwargs, -) -> shap.Explainer: - """Create an explainer for the provided campaign. - - Args: - campaign: The campaign to be explained. - explainer_class: The explainer to be used. Default is shap.KernelExplainer. - computational_representation: Whether to compute the Shapley values - in computational or experimental searchspace. - Default is False. - **kwargs: Additional keyword arguments to be passed to the explainer. - - Returns: - The explainer for the provided campaign. 
- - Raises: - ValueError: If no measurements have been provided yet. - NotImplementedError: If the provided explainer does not support - the campaign surrogate. - TypeError: If the provided explainer does not support the campaign surrogate. - """ - if campaign.measurements.empty: - raise ValueError("No measurements have been provided yet.") - - data = campaign.measurements[[p.name for p in campaign.parameters]].copy() - - if computational_representation: - data = campaign.searchspace.transform(data) - - def model(x): - tensor = to_tensor(x) - output = campaign.get_surrogate()._posterior_comp(tensor).mean - - return output.detach().numpy() - else: - - def model(x): - df = pd.DataFrame(x, columns=data.columns) - output = campaign.get_surrogate().posterior(df).mean - - return output.detach().numpy() - - if ( - campaign.searchspace.type != "CONTINUOUS" - and not computational_representation - and not explainer_class == shap.KernelExplainer - ): - raise NotImplementedError( - "Only KernelExplainer is supported for non-continous searchspaces." - ) - - try: - shap_explainer = explainer_class(model, data, **kwargs) - except shap.utils._exceptions.InvalidModelError: - raise TypeError( - "The selected explainer class does not support the campaign surrogate." - ) - return shap_explainer - - -def explanation( - campaign: Campaign, - data: pd.DataFrame | None = None, - explainer_class: shap.Explainer = shap.KernelExplainer, - computational_representation: bool = False, - **kwargs, -) -> shap.Explanation: - """Compute the Shapley values for the provided campaign and data. - - Args: - campaign: The campaign to be explained. - data: The data to be explained. - Default is None which uses the campaign's measurements. - explainer_class: The explainer to be used. - Default is shap.KernelExplainer. - computational_representation: Whether to compute the Shapley values - in computational or experimental searchspace. - Default is False. 
- **kwargs: Additional keyword arguments to be passed to the explainer. - - Returns: - The Shapley values for the provided campaign. - - Raises: - ValueError: If the provided explainer does not support experimental - representation. - NotImplementedError: If the provided explainer does not support - the campaign surrogate. - ValueError: If the provided data does not have the same amount of parameters - as the campaign. - ValueError: If the resulting explanation object has an invalid shape. - """ - is_shap_explainer = not explainer_class.__module__.startswith( - "shap.explainers.other." - ) - - if not is_shap_explainer and not computational_representation: - raise ValueError( - "Experimental representation is not " - "supported for non-Kernel SHAP explainer." - ) - - try: - explainer_obj = explainer( - campaign, - explainer_class=explainer_class, - computational_representation=computational_representation, - **kwargs, - ) - except NotImplementedError: - warnings.warn( - "The provided Explainer does not support experimental representation. " - "Switching to computational representation. " - "Otherwise consider using a different explainer (e.g. KernelExplainer)." - ) - return explanation( - campaign, - data=data, - explainer_class=explainer_class, - computational_representation=True, - **kwargs, - ) - - if data is None: - data = campaign.measurements[[p.name for p in campaign.parameters]].copy() - elif set(campaign.searchspace.parameter_names) != set(data.columns.values): - raise ValueError( - "The provided data does not have the same amount of parameters " - "as specified for the campaign." - ) - if computational_representation: - data = campaign.searchspace.transform(pd.DataFrame(data)) - - # Type checking for mypy - assert isinstance(data, pd.DataFrame) - - if not _get_explainer_feature_len(explainer_obj) == data.shape[1]: - raise ValueError( - "The provided data does not have the same amount of " - "parameters as the shap explainer background." 
- ) - - if not is_shap_explainer: - """Return attributions for non-SHAP explainers.""" - if explainer_class.__module__.endswith("maple"): - """Additional argument for maple to increase comparability to SHAP.""" - attributions = explainer_obj.attributions(data, multiply_by_input=True)[0] - else: - attributions = explainer_obj.attributions(data)[0] - feature_names = ( - campaign.searchspace.comp_rep_columns - if computational_representation - else campaign.searchspace.parameter_names - ) - explanations = shap.Explanation( - values=attributions, - base_values=data, - data=data, - ) - explanations.feature_names = list(feature_names) - return explanations - else: - explanations = explainer_obj(data) - - if len(explanations.shape) == 2: - return explanations - if len(explanations.shape) == 3: - return explanations[:, :, 0] - raise ValueError( - "The Explanation has an invalid " - f"dimensionality of {len(explanations.shape)}." - ) - - -def plot_shap_scatter(explanation: memoryview | shap.Explanation, **kwargs) -> None: - """Plot the Shapley values using a scatter plot while leaving out string values. - - Args: - explanation: The Shapley values to be plotted. - **kwargs: Additional keyword arguments to be passed to the scatter plot. - - Raises: - ValueError: If the provided explanation object does not match the - required types. - """ - if isinstance(explanation, memoryview): - data = explanation.obj - elif isinstance(explanation, shap.Explanation): - data = explanation.data.data.obj - else: - raise ValueError("The provided explanation argument is not of a valid type.") - - def is_not_numeric_column(col): - return np.array([not isinstance(v, numbers.Number) for v in col]).any() - - if np.ndim(data) == 1: - if is_not_numeric_column(data): - warnings.warn( - "Cannot plot scatter plot for the provided " - "explanation as it contains non-numeric values." 
- ) - else: - shap.plots.scatter(explanation, **kwargs) - else: - # Type checking for mypy - assert isinstance(data, np.ndarray) - - number_enum = [i for i, x in enumerate(data[1]) if not isinstance(x, str)] - if len(number_enum) < len(explanation.feature_names): - warnings.warn( - "Cannot plot SHAP scatter plot for all " - "parameters as some contain non-numeric values." - ) - shap.plots.scatter(explanation[:, number_enum], **kwargs) diff --git a/baybe/insights/__init__.py b/baybe/insights/__init__.py new file mode 100644 index 000000000..0cc340c51 --- /dev/null +++ b/baybe/insights/__init__.py @@ -0,0 +1,12 @@ +"""Baybe insights (optional).""" + +from baybe._optional.info import INSIGHTS_INSTALLED + +if INSIGHTS_INSTALLED: + from baybe.insights.base import Insight + from baybe.insights.shap import SHAPInsight + +__all__ = [ + "SHAPInsight", + "Insight", +] diff --git a/baybe/insights/base.py b/baybe/insights/base.py new file mode 100644 index 000000000..60a7e81ca --- /dev/null +++ b/baybe/insights/base.py @@ -0,0 +1,45 @@ +"""Base class for all insights.""" + +from abc import ABC + +import pandas as pd + +from baybe import Campaign +from baybe._optional.info import INSIGHTS_INSTALLED +from baybe.objectives.base import Objective +from baybe.recommenders.pure.bayesian.base import BayesianRecommender +from baybe.searchspace import SearchSpace + +if INSIGHTS_INSTALLED: + pass + + +class Insight(ABC): + """Base class for all insights.""" + + def __init__(self, surrogate): + self.surrogate = surrogate + + @classmethod + def from_campaign(cls, campaign: Campaign): + """Create an insight from a campaign.""" + return cls(campaign.get_surrogate()) + + @classmethod + def from_recommender( + cls, + recommender: BayesianRecommender, + searchspace: SearchSpace, + objective: Objective, + bg_data: pd.DataFrame, + ): + """Create an insight from a recommender.""" + if not hasattr(recommender, "get_surrogate"): + raise ValueError( + "The provided recommender does not provide a 
surrogate model." + ) + surrogate_model = recommender.get_surrogate(searchspace, objective, bg_data) + + return cls( + surrogate_model, + ) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py new file mode 100644 index 000000000..d431b93ae --- /dev/null +++ b/baybe/insights/shap.py @@ -0,0 +1,319 @@ +"""SHAP utilities.""" + +import numbers +import warnings + +import numpy as np +import pandas as pd +from typing_extensions import override + +from baybe import Campaign +from baybe._optional.insights import shap +from baybe.insights.base import Insight +from baybe.objectives.base import Objective +from baybe.recommenders.pure.bayesian.base import BayesianRecommender +from baybe.searchspace import SearchSpace +from baybe.utils.dataframe import to_tensor + + +class SHAPInsight(Insight): + """Base class for all SHAP insights.""" + + DEFAULT_SHAP_PLOTS = [ + "bar", + "scatter", + "heatmap", + "force", + "beeswarm", + ] + + def __init__( + self, + surrogate_model, + bg_data: pd.DataFrame, + explained_data: pd.DataFrame | None = None, + explainer_class: shap.Explainer = shap.KernelExplainer, + computational_representation: bool = False, + ): + super().__init__(surrogate_model) + self._computational_representation = computational_representation + self._is_shap_explainer = not explainer_class.__module__.startswith( + "shap.explainers.other." + ) + self._bg_data = bg_data + self._explained_data = explained_data + self.explainer = self._get_explainer(bg_data, explainer_class) + self._explanation = None + + @override + @classmethod + def from_campaign( + cls, + campaign: Campaign, + explainer_class: shap.Explainer = shap.KernelExplainer, + computational_representation: bool = False, + ): + """Create a SHAP insight from a campaign. + + Args: + campaign: The campaign to be used for the SHAP insight. + explainer_class: The explainer class to be used for the computation. + computational_representation: + Whether to use the computational representation. 
+ + Returns: + SHAPInsight: The SHAP insight object. + + Raises: + ValueError: If the campaign does not contain any measurements. + """ + if campaign.measurements.empty: + raise ValueError("The campaign does not contain any measurements.") + data = campaign.measurements[[p.name for p in campaign.parameters]].copy() + return cls( + campaign.get_surrogate(), + bg_data=campaign.searchspace.transform(data) + if computational_representation + else data, + explainer_class=explainer_class, + computational_representation=computational_representation, + ) + + @override + @classmethod + def from_recommender( + cls, + recommender: BayesianRecommender, + searchspace: SearchSpace, + objective: Objective, + bg_data: pd.DataFrame, + explained_data: pd.DataFrame | None = None, + explainer_class: shap.Explainer = shap.KernelExplainer, + computational_representation: bool = False, + ): + """Create a SHAP insight from a recommender. + + Args: + recommender: The recommender to be used for the SHAP insight. + searchspace: The searchspace for the recommender. + objective: The objective for the recommender. + bg_data: The background data set for Explainer. + This is also the measurement data set for the recommender. + explained_data: The data set to be explained. If None, + the background data set is used. + explainer_class: The explainer class. + computational_representation: + Whether to use the computational representation. + + Returns: + SHAPInsight: The SHAP insight object. + + Raises: + ValueError: If the recommender has not implemented a "get_surrogate" method. + """ + if not hasattr(recommender, "get_surrogate"): + raise ValueError( + "The provided recommender does not provide a surrogate model." 
+ ) + surrogate_model = recommender.get_surrogate(searchspace, objective, bg_data) + + return cls( + surrogate_model, + bg_data=searchspace.transform(bg_data) + if computational_representation + else bg_data, + explained_data=explained_data, + explainer_class=explainer_class, + computational_representation=computational_representation, + ) + + def _get_explainer( + self, + data: pd.DataFrame, + explainer_class: type[shap.Explainer] = shap.KernelExplainer, + **kwargs, + ) -> shap.Explainer: + """Create an explainer for the provided campaign. + + Args: + data: The background data set. + explainer_class: The explainer class to be used. + **kwargs: Additional keyword arguments to be passed to the explainer. + + Returns: + shap.Explainer: The created explainer object. + + Raises: + NotImplementedError: If the provided explainer class does + not support the experimental representation. + ValueError: If the provided background data set is empty. + TypeError: If the provided explainer class does not + support the campaign surrogate. + """ + if not self._is_shap_explainer and not self._computational_representation: + raise NotImplementedError( + "Experimental representation is not " + "supported for non-Kernel SHAP explainer." + ) + + if data.empty: + raise ValueError("The provided background data set is empty.") + + if self._computational_representation: + + def model(x): + tensor = to_tensor(x) + output = self.surrogate._posterior_comp(tensor).mean + + return output.detach().numpy() + else: + + def model(x): + df = pd.DataFrame(x, columns=data.columns) + output = self.surrogate.posterior(df).mean + + return output.detach().numpy() + + try: + shap_explainer = explainer_class(model, data, **kwargs) + except shap.utils._exceptions.InvalidModelError: + raise TypeError( + "The selected explainer class does not support the campaign surrogate." 
+ ) + except TypeError as e: + if ( + "not supported for the input types, and the inputs could " + "not be safely coerced to any supported types" + in str(e) + and not self._computational_representation + ): + raise NotImplementedError( + "The selected explainer class does not support experimental " + "representation. Switch to computational representation " + "or use a different explainer " + "(e.g. the default shap.KernelExplainer)." + ) + return shap_explainer + + def _get_explanation( + self, + data: pd.DataFrame | None = None, + explainer_class: type[shap.Explainer] = shap.KernelExplainer, + ) -> shap.Explanation: + """Compute the Shapley values based on the chosen explainer and data set. + + Args: + data: The data set for which the Shapley values should be computed. + explainer_class: The explainer class to be used for the computation. + + Returns: + shap.Explanation: The computed Shapley values. + + Raises: + ValueError: If the provided data set does not have the same amount of + parameters as the SHAP explainer background + """ + if data is None: + data = self._bg_data + elif not self._bg_data.shape[1] == data.shape[1]: + raise ValueError( + "The provided data does not have the same amount of " + "parameters as the shap explainer background." 
+ ) + + # Type checking for mypy + assert isinstance(data, pd.DataFrame) + + if not self._is_shap_explainer: + """Return attributions for non-SHAP explainers.""" + if explainer_class.__module__.endswith("maple"): + """Additional argument for maple to increase comparability to SHAP.""" + attributions = self.explainer.attributions( + data, multiply_by_input=True + )[0] + else: + attributions = self.explainer.attributions(data)[0] + explanations = shap.Explanation( + values=attributions, + base_values=data, + data=data, + feature_names=data.columns.values, + ) + return explanations + else: + explanations = self.explainer(data) + + """Ensure that the explanation object is of the correct dimensionality.""" + if len(explanations.shape) == 2: + return explanations + if len(explanations.shape) == 3: + return explanations[:, :, 0] + raise ValueError( + "The Explanation has an invalid " + f"dimensionality of {len(explanations.shape)}." + ) + + @property + def explanation(self) -> shap.Explanation: + """Get the SHAP explanation object. Uses lazy evaluation. + + Returns: + shap.Explanation: The SHAP explanation object. + """ + if self._explanation is None: + self._explanation = self._get_explanation() + + return self._explanation + + def plot(self, plot_type: str, **kwargs) -> None: + """Plot the Shapley values using the provided plot type. + + Args: + plot_type: The type of plot to be created. Supported types are: + "bar", "scatter", "heatmap", "force", "beeswarm". + **kwargs: Additional keyword arguments to be passed to the plot function. 
+ + Raises: + ValueError: If the provided plot type is not supported + """ + if plot_type == "scatter": + self._plot_shap_scatter(**kwargs) + return None + + plot = getattr(shap.plots, plot_type, None) + if ( + plot is None + or not callable(plot) + or plot_type not in self.DEFAULT_SHAP_PLOTS + ): + raise ValueError(f"Invalid plot type: {plot_type}") + + plot(self.explanation, **kwargs) + + def _plot_shap_scatter(self) -> None: + """Plot the Shapley values as scatter plot while leaving out string values.""" + + def is_not_numeric_column(col): + return np.array([not isinstance(v, numbers.Number) for v in col]).any() + + if np.ndim(self._bg_data) == 1: + if is_not_numeric_column(self._bg_data): + warnings.warn( + "Cannot plot scatter plot for the provided " + "explanation as it contains non-numeric values." + ) + else: + shap.plots.scatter(self.explanation) + else: + # Type checking for mypy + assert isinstance(self._bg_data, pd.DataFrame) + + mask = self._bg_data.iloc[0].apply(lambda x: not isinstance(x, str)) + number_enum = np.where(mask)[0].tolist() + + if len(number_enum) < len(self._bg_data.iloc[0]): + warnings.warn( + "Cannot plot SHAP scatter plot for all " + "parameters as some contain non-numeric values." 
+ ) + shap.plots.scatter(self.explanation[:, number_enum]) diff --git a/pyproject.toml b/pyproject.toml index 9d17da3da..29a3bc87c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ onnx = [ dev = [ "baybe[chem]", - "baybe[diagnostics]", + "baybe[insights]", "baybe[docs]", "baybe[examples]", "baybe[lint]", @@ -90,13 +90,13 @@ dev = [ "baybe[polars]", "baybe[simulation]", "baybe[test]", - "baybe[diagnostics]", + "baybe[insights]", "pip-audit>=2.5.5", "tox-uv>=1.7.0", "uv>=0.3.0", # `uv lock` (for lockfiles) is stable since 0.3.0: https://github.com/astral-sh/uv/issues/2679#event-13950215962 ] -diagnostics = [ +insights = [ "shap>=0.46.0", "lime>=0.2.0.1" ] diff --git a/pytest.ini b/pytest.ini index 36bc17480..2dc951c4d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,7 +10,7 @@ addopts = --ignore=baybe/_optional --ignore=baybe/utils/chemistry.py --ignore=tests/simulate_telemetry.py - --ignore=baybe/diagnostics + --ignore=baybe/insights testpaths = baybe tests \ No newline at end of file diff --git a/tests/diagnostics/test_shap.py b/tests/diagnostics/test_shap.py deleted file mode 100644 index b07e75ad2..000000000 --- a/tests/diagnostics/test_shap.py +++ /dev/null @@ -1,154 +0,0 @@ -"""Tests for diagnostic utilities.""" - -import inspect - -import pandas as pd -import pytest -from pytest import param - -from baybe._optional.info import DIAGNOSTICS_INSTALLED -from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender -from baybe.recommenders.pure.bayesian.base import BayesianRecommender -from baybe.searchspace import SearchSpaceType -from baybe.utils.basic import get_subclasses -from tests.conftest import run_iterations - -pytestmark = pytest.mark.skipif( - not DIAGNOSTICS_INSTALLED, reason="Optional diagnostics dependency not installed." 
-) - -if DIAGNOSTICS_INSTALLED: - import shap - - from baybe import diagnostics as diag - -EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] - - -def _has_required_init_parameters(cls): - """Helper function checks if initializer has required standard parameters.""" - REQUIRED_PARAMETERS = ["self", "model", "data"] - init_signature = inspect.signature(cls.__init__) - parameters = list(init_signature.parameters.keys()) - return parameters[:3] == REQUIRED_PARAMETERS - - -non_shap_explainers = ( - [ - param(explainer, id=f"{cls_name}") - for cls_name in shap.explainers.other.__all__ - if _has_required_init_parameters( - explainer := getattr(shap.explainers.other, cls_name) - ) - and all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) - ] - if DIAGNOSTICS_INSTALLED - else [] -) - -shap_explainers = ( - [ - param(getattr(shap.explainers, cls_name), id=f"{cls_name}") - for cls_name in shap.explainers.__all__ - if all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) - ] - if DIAGNOSTICS_INSTALLED - else [] -) - -valid_hybrid_bayesian_recommenders = [ - param(TwoPhaseMetaRecommender(recommender=cls()), id=f"{cls.__name__}") - for cls in get_subclasses(BayesianRecommender) - if cls.compatibility == SearchSpaceType.HYBRID -] - - -def _test_explainer(campaign, explainer_cls, use_comp_rep): - """Helper function for general explainer tests.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - try: - shap_val = diag.explanation( - campaign, - computational_representation=use_comp_rep, - explainer_class=explainer_cls, - ) - assert isinstance(shap_val, shap.Explanation) - df = pd.DataFrame({"Num_disc_1": [0, 2]}) - with pytest.raises( - ValueError, - match="The provided data does not have the same " - "amount of parameters as specified for the campaign.", - ): - diag.explanation( - campaign, - data=df, - computational_representation=True, - explainer_class=explainer_cls, - ) - except TypeError as e: - if ( - "The selected 
explainer class does not support the campaign surrogate." - in str(e) - ): - pass - else: - raise e - - -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], -) -def test_shapley_values_no_measurements(campaign): - """A campaign without measurements raises an error.""" - with pytest.raises(ValueError, match="No measurements have been provided yet."): - diag.explanation(campaign) - - -@pytest.mark.slow -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize("explainer_cls", shap_explainers) -@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) -@pytest.mark.parametrize("parameter_names", [["Conti_finite1", "Conti_finite2"]]) -def test_shapley_with_measurements_continuous(campaign, explainer_cls, use_comp_rep): - """Test the explain functionalities with measurements.""" - _test_explainer(campaign, explainer_cls, use_comp_rep) - - -@pytest.mark.slow -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize("explainer_cls", shap_explainers) -@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], -) -def test_shapley_with_measurements(campaign, explainer_cls, use_comp_rep): - """Test the explain functionalities with measurements.""" - _test_explainer(campaign, explainer_cls, use_comp_rep) - - -@pytest.mark.parametrize("explainer_cls", non_shap_explainers) -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], - ids=["params1"], -) -def test_non_shapley_explainers(campaign, explainer_cls): - """Test the explain functionalities with the non-SHAP explainer MAPLE.""" - """Test the non-SHAP explainer in computational representation.""" - _test_explainer(campaign, explainer_cls, use_comp_rep=True) - 
"""Ensure that an error is raised if non-computational representation - is used with a non-Kernel SHAP explainer.""" - with pytest.raises( - ValueError, - match=( - "Experimental representation is not supported " - "for non-Kernel SHAP explainer." - ), - ): - diag.explanation( - campaign, - computational_representation=False, - explainer_class=explainer_cls, - ) diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py new file mode 100644 index 000000000..9624ae8c8 --- /dev/null +++ b/tests/insights/test_shap.py @@ -0,0 +1,198 @@ +"""Tests for insights subpackage.""" + +import inspect +from unittest import mock + +import pandas as pd +import pytest +from pytest import param + +from baybe._optional.info import INSIGHTS_INSTALLED +from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender +from baybe.recommenders.pure.bayesian.base import BayesianRecommender +from baybe.searchspace import SearchSpaceType +from baybe.utils.basic import get_subclasses +from tests.conftest import run_iterations + +pytestmark = pytest.mark.skipif( + not INSIGHTS_INSTALLED, reason="Optional insights dependency not installed." 
+) + +if INSIGHTS_INSTALLED: + import shap + + from baybe import insights + from baybe.insights.shap import SHAPInsight + + DEFAULT_SHAP_PLOTS = insights.SHAPInsight.DEFAULT_SHAP_PLOTS +else: + DEFAULT_SHAP_PLOTS = [] + +EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] + + +def _has_required_init_parameters(cls): + """Helper function checks if initializer has required standard parameters.""" + REQUIRED_PARAMETERS = ["self", "model", "data"] + init_signature = inspect.signature(cls.__init__) + parameters = list(init_signature.parameters.keys()) + return parameters[:3] == REQUIRED_PARAMETERS + + +non_shap_explainers = ( + [ + param(explainer, id=f"{cls_name}") + for cls_name in shap.explainers.other.__all__ + if _has_required_init_parameters( + explainer := getattr(shap.explainers.other, cls_name) + ) + and all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) + ] + if INSIGHTS_INSTALLED + else [] +) + +shap_explainers = ( + [ + param(getattr(shap.explainers, cls_name), id=f"{cls_name}") + for cls_name in shap.explainers.__all__ + if all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) + ] + if INSIGHTS_INSTALLED + else [] +) + +valid_hybrid_bayesian_recommenders = [ + param(TwoPhaseMetaRecommender(recommender=cls()), id=f"{cls.__name__}") + for cls in get_subclasses(BayesianRecommender) + if cls.compatibility == SearchSpaceType.HYBRID +] + + +def _test_shap_insights(campaign, explainer_cls, use_comp_rep, is_shap): + """Helper function for general SHAP explainer tests.""" + run_iterations(campaign, n_iterations=2, batch_size=1) + try: + shap_insights = SHAPInsight.from_campaign( + campaign, + explainer_class=explainer_cls, + computational_representation=use_comp_rep, + ) + assert isinstance(shap_insights, insights.SHAPInsight) + assert isinstance(shap_insights.explainer, explainer_cls) + assert shap_insights._is_shap_explainer == is_shap + shap_explanation = shap_insights.explanation + assert isinstance(shap_explanation, 
shap.Explanation) + df = pd.DataFrame({"Num_disc_1": [0, 2]}) + with pytest.raises( + ValueError, + match="The provided data does not have the same " + "amount of parameters as the shap explainer background.", + ): + shap_insights = SHAPInsight.from_campaign( + campaign, + explainer_class=explainer_cls, + explained_data=df, + ) + except TypeError as e: + if ( + "The selected explainer class does not support the campaign surrogate." + in str(e) + ): + pass + except NotImplementedError as e: + if ( + "The selected explainer class does not support experimental " + "representation. Switch to computational representation or " + "use a different explainer (e.g. the default " + "shap.KernelExplainer)." + in str(e) + and not use_comp_rep + and not isinstance(explainer_cls, shap.explainers.KernelExplainer) + ): + pass + + +@pytest.mark.slow +@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) +@pytest.mark.parametrize("explainer_cls", shap_explainers) +@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) +@pytest.mark.parametrize( + "parameter_names", + [ + ["Conti_finite1", "Conti_finite2"], + ["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"], + ], + ids=["continuous_params", "hybrid_params"], +) +def test_shapley_with_measurements(campaign, explainer_cls, use_comp_rep): + """Test the explain functionalities with measurements.""" + _test_shap_insights(campaign, explainer_cls, use_comp_rep, is_shap=True) + + +@pytest.mark.parametrize("explainer_cls", non_shap_explainers) +@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) +@pytest.mark.parametrize( + "parameter_names", + [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], + ids=["hybrid_params"], +) +def test_non_shapley_explainers(campaign, explainer_cls): + """Test the explain functionalities with the non-SHAP explainer MAPLE.""" + """Test the non-SHAP explainer in 
computational representation.""" + _test_shap_insights(campaign, explainer_cls, use_comp_rep=True, is_shap=False) + + +@pytest.mark.slow +@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) +@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) +@pytest.mark.parametrize("plot_type", DEFAULT_SHAP_PLOTS) +@pytest.mark.parametrize( + "parameter_names", + [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], + ids=["hybrid_params"], +) +def test_shap_insight_plots(campaign, use_comp_rep, plot_type): + """Test the default SHAP plots.""" + run_iterations(campaign, n_iterations=2, batch_size=1) + shap_insights = SHAPInsight.from_campaign( + campaign, + computational_representation=use_comp_rep, + ) + with mock.patch("matplotlib.pyplot.show"): + shap_insights.plot(plot_type) + + +@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) +def test_updated_campaign_explanations(campaign): + """Test explanations for campaigns with updated measurements.""" + with pytest.raises( + ValueError, + match="The campaign does not contain any measurements.", + ): + shap_insights = SHAPInsight.from_campaign(campaign) + run_iterations(campaign, n_iterations=2, batch_size=1) + shap_insights = SHAPInsight.from_campaign(campaign) + explanation_two_iter = shap_insights.explanation + run_iterations(campaign, n_iterations=2, batch_size=1) + shap_insights = SHAPInsight.from_campaign(campaign) + explanation_four_iter = shap_insights.explanation + assert explanation_two_iter != explanation_four_iter + + +@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) +def test_shap_insights_from_recommender(campaign): + """Test the creation of SHAP insights from a recommender.""" + run_iterations(campaign, n_iterations=2, 
batch_size=1) + recommender = campaign.recommender.recommender + shap_insight = SHAPInsight.from_recommender( + recommender, + campaign.searchspace, + campaign.objective, + campaign.measurements, + ) + assert isinstance(shap_insight, insights.SHAPInsight) diff --git a/tox.ini b/tox.ini index 1f90161de..ea5c7b12d 100644 --- a/tox.ini +++ b/tox.ini @@ -5,7 +5,7 @@ isolated_build = True [testenv:fulltest,fulltest-py{310,311,312}] description = Run PyTest with all extra functionality -extras = chem,examples,lint,onnx,polars,diagnostics,simulation,test +extras = chem,examples,lint,onnx,polars,insights,simulation,test passenv = CI BAYBE_NUMPY_USE_SINGLE_PRECISION From 3fef7b35b5c89ebaecf2a37dbc31c435672934c1 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Tue, 10 Dec 2024 18:37:34 +0100 Subject: [PATCH 21/92] Moved explainer maps from testing of shap functionality to shap, allowing explainers to be specified via strings. General cleanup for PR. --- baybe/insights/shap.py | 98 ++++++++++++++++++++++++++++--------- tests/insights/test_shap.py | 60 ++++++----------------- 2 files changed, 92 insertions(+), 66 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index d431b93ae..4f51d6548 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -1,11 +1,12 @@ """SHAP utilities.""" +import inspect import numbers import warnings +from typing import Any, override import numpy as np import pandas as pd -from typing_extensions import override from baybe import Campaign from baybe._optional.insights import shap @@ -19,30 +20,76 @@ class SHAPInsight(Insight): """Base class for all SHAP insights.""" - DEFAULT_SHAP_PLOTS = [ + DEFAULT_SHAP_PLOTS = { "bar", "scatter", "heatmap", "force", "beeswarm", - ] + } + + @staticmethod + def _get_explainer_maps() -> ( + tuple[dict[str, type[shap.Explainer]], dict[str, type[shap.Explainer]]] + ): + """Get explainer maps for SHAP and non-SHAP explainers. 
+ + Returns: + tuple[dict[str, type[shap.Explainer]], dict[str, type[shap.Explainer]]]: + The explainer maps for SHAP and non-SHAP explainers. + """ + EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] + + def _has_required_init_parameters(cls): + """Check if non-shap initializer has required standard parameters.""" + REQUIRED_PARAMETERS = ["self", "model", "data"] + init_signature = inspect.signature(cls.__init__) + parameters = list(init_signature.parameters.keys()) + return parameters[:3] == REQUIRED_PARAMETERS + + shap_explainers = { + cls_name: getattr(shap.explainers, cls_name) + for cls_name in shap.explainers.__all__ + if all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) + } + + non_shap_explainers = { + cls_name: explainer + for cls_name in shap.explainers.other.__all__ + if _has_required_init_parameters( + explainer := getattr(shap.explainers.other, cls_name) + ) + and all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) + } + + return shap_explainers, non_shap_explainers + + SHAP_EXPLAINERS, NON_SHAP_EXPLAINERS = _get_explainer_maps() + + ALL_EXPLAINERS = {**SHAP_EXPLAINERS, **NON_SHAP_EXPLAINERS} def __init__( self, surrogate_model, bg_data: pd.DataFrame, explained_data: pd.DataFrame | None = None, - explainer_class: shap.Explainer = shap.KernelExplainer, + explainer_class: shap.Explainer | str = "KernelExplainer", computational_representation: bool = False, ): super().__init__(surrogate_model) self._computational_representation = computational_representation - self._is_shap_explainer = not explainer_class.__module__.startswith( + explainer_cls = ( + explainer_class + if not isinstance(explainer_class, str) + or explainer_class not in self.ALL_EXPLAINERS + else self.ALL_EXPLAINERS[explainer_class] + ) + self._is_shap_explainer = not explainer_cls.__module__.startswith( "shap.explainers.other." 
) self._bg_data = bg_data self._explained_data = explained_data - self.explainer = self._get_explainer(bg_data, explainer_class) + self.explainer = self._init_explainer(bg_data, explainer_cls) # type: ignore[arg-type] self._explanation = None @override @@ -50,13 +97,16 @@ def __init__( def from_campaign( cls, campaign: Campaign, - explainer_class: shap.Explainer = shap.KernelExplainer, + explained_data: pd.DataFrame | None = None, + explainer_class: shap.Explainer | str = "KernelExplainer", computational_representation: bool = False, ): """Create a SHAP insight from a campaign. Args: campaign: The campaign to be used for the SHAP insight. + explained_data: The data set to be explained. If None, + all measurements from the campaign are used. explainer_class: The explainer class to be used for the computation. computational_representation: Whether to use the computational representation. @@ -77,6 +127,7 @@ def from_campaign( else data, explainer_class=explainer_class, computational_representation=computational_representation, + explained_data=explained_data, ) @override @@ -88,7 +139,7 @@ def from_recommender( objective: Objective, bg_data: pd.DataFrame, explained_data: pd.DataFrame | None = None, - explainer_class: shap.Explainer = shap.KernelExplainer, + explainer_class: shap.Explainer | str = "KernelExplainer", computational_representation: bool = False, ): """Create a SHAP insight from a recommender. 
@@ -127,7 +178,7 @@ def from_recommender( computational_representation=computational_representation, ) - def _get_explainer( + def _init_explainer( self, data: pd.DataFrame, explainer_class: type[shap.Explainer] = shap.KernelExplainer, @@ -176,35 +227,34 @@ def model(x): try: shap_explainer = explainer_class(model, data, **kwargs) + """Explain first two data points to ensure that the explainer is working.""" + if self._is_shap_explainer: + shap_explainer(self._bg_data.iloc[0:1]) except shap.utils._exceptions.InvalidModelError: raise TypeError( "The selected explainer class does not support the campaign surrogate." ) except TypeError as e: if ( - "not supported for the input types, and the inputs could " - "not be safely coerced to any supported types" - in str(e) + "not supported for the input types" in str(e) and not self._computational_representation ): raise NotImplementedError( "The selected explainer class does not support experimental " - "representation. Switch to computational representation " + "representation. Switch to computational representation " "or use a different explainer " "(e.g. the default shap.KernelExplainer)." ) return shap_explainer - def _get_explanation( + def _init_explanation( self, data: pd.DataFrame | None = None, - explainer_class: type[shap.Explainer] = shap.KernelExplainer, ) -> shap.Explanation: """Compute the Shapley values based on the chosen explainer and data set. Args: data: The data set for which the Shapley values should be computed. - explainer_class: The explainer class to be used for the computation. Returns: shap.Explanation: The computed Shapley values. 
@@ -226,7 +276,7 @@ def _get_explanation( if not self._is_shap_explainer: """Return attributions for non-SHAP explainers.""" - if explainer_class.__module__.endswith("maple"): + if self.explainer.__module__.endswith("maple"): """Additional argument for maple to increase comparability to SHAP.""" attributions = self.explainer.attributions( data, multiply_by_input=True @@ -235,7 +285,7 @@ def _get_explanation( attributions = self.explainer.attributions(data)[0] explanations = shap.Explanation( values=attributions, - base_values=data, + base_values=self.explainer.model(self._bg_data).mean(), data=data, feature_names=data.columns.values, ) @@ -261,7 +311,7 @@ def explanation(self) -> shap.Explanation: shap.Explanation: The SHAP explanation object. """ if self._explanation is None: - self._explanation = self._get_explanation() + self._explanation = self._init_explanation() return self._explanation @@ -290,8 +340,12 @@ def plot(self, plot_type: str, **kwargs) -> None: plot(self.explanation, **kwargs) - def _plot_shap_scatter(self) -> None: - """Plot the Shapley values as scatter plot while leaving out string values.""" + def _plot_shap_scatter(self, **kwargs: Any) -> None: + """Plot the Shapley values as scatter plot while leaving out string values. + + Args: + **kwargs: Additional keyword arguments to be passed to the plot function. + """ def is_not_numeric_column(col): return np.array([not isinstance(v, numbers.Number) for v in col]).any() @@ -316,4 +370,4 @@ def is_not_numeric_column(col): "Cannot plot SHAP scatter plot for all " "parameters as some contain non-numeric values." 
) - shap.plots.scatter(self.explanation[:, number_enum]) + shap.plots.scatter(self.explanation[:, number_enum], **kwargs) diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 9624ae8c8..d3810056d 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -1,6 +1,5 @@ """Tests for insights subpackage.""" -import inspect from unittest import mock import pandas as pd @@ -24,43 +23,13 @@ from baybe import insights from baybe.insights.shap import SHAPInsight - DEFAULT_SHAP_PLOTS = insights.SHAPInsight.DEFAULT_SHAP_PLOTS + default_shap_plots = SHAPInsight.DEFAULT_SHAP_PLOTS + shap_explainers, non_shap_explainers = ( + SHAPInsight.SHAP_EXPLAINERS, + SHAPInsight.NON_SHAP_EXPLAINERS, + ) else: - DEFAULT_SHAP_PLOTS = [] - -EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] - - -def _has_required_init_parameters(cls): - """Helper function checks if initializer has required standard parameters.""" - REQUIRED_PARAMETERS = ["self", "model", "data"] - init_signature = inspect.signature(cls.__init__) - parameters = list(init_signature.parameters.keys()) - return parameters[:3] == REQUIRED_PARAMETERS - - -non_shap_explainers = ( - [ - param(explainer, id=f"{cls_name}") - for cls_name in shap.explainers.other.__all__ - if _has_required_init_parameters( - explainer := getattr(shap.explainers.other, cls_name) - ) - and all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) - ] - if INSIGHTS_INSTALLED - else [] -) - -shap_explainers = ( - [ - param(getattr(shap.explainers, cls_name), id=f"{cls_name}") - for cls_name in shap.explainers.__all__ - if all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) - ] - if INSIGHTS_INSTALLED - else [] -) + default_shap_plots, shap_explainers, non_shap_explainers = [], [], [] valid_hybrid_bayesian_recommenders = [ param(TwoPhaseMetaRecommender(recommender=cls()), id=f"{cls.__name__}") @@ -79,7 +48,10 @@ def _test_shap_insights(campaign, explainer_cls, use_comp_rep, 
is_shap): computational_representation=use_comp_rep, ) assert isinstance(shap_insights, insights.SHAPInsight) - assert isinstance(shap_insights.explainer, explainer_cls) + assert isinstance( + shap_insights.explainer, + SHAPInsight.ALL_EXPLAINERS[explainer_cls], + ) assert shap_insights._is_shap_explainer == is_shap shap_explanation = shap_insights.explanation assert isinstance(shap_explanation, shap.Explanation) @@ -89,17 +61,15 @@ def _test_shap_insights(campaign, explainer_cls, use_comp_rep, is_shap): match="The provided data does not have the same " "amount of parameters as the shap explainer background.", ): - shap_insights = SHAPInsight.from_campaign( - campaign, - explainer_class=explainer_cls, - explained_data=df, - ) + shap_insights._init_explanation(df) except TypeError as e: if ( "The selected explainer class does not support the campaign surrogate." in str(e) ): pass + else: + raise e except NotImplementedError as e: if ( "The selected explainer class does not support experimental " @@ -111,6 +81,8 @@ def _test_shap_insights(campaign, explainer_cls, use_comp_rep, is_shap): and not isinstance(explainer_cls, shap.explainers.KernelExplainer) ): pass + else: + raise e @pytest.mark.slow @@ -148,7 +120,7 @@ def test_non_shapley_explainers(campaign, explainer_cls): @pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) @pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) @pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) -@pytest.mark.parametrize("plot_type", DEFAULT_SHAP_PLOTS) +@pytest.mark.parametrize("plot_type", default_shap_plots) @pytest.mark.parametrize( "parameter_names", [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], From 949629f9ff874a807906545440d6be160075e541 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Tue, 10 Dec 2024 19:22:10 +0100 Subject: [PATCH 22/92] Cleanup of CONTRIBUTORS.md --- CONTRIBUTORS.md | 1 - 1 file changed, 1 deletion(-) diff --git 
a/CONTRIBUTORS.md b/CONTRIBUTORS.md index c32ff6a38..0b7f5c4a6 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -30,6 +30,5 @@ `scikit-fingerprints` support - Fabian Liebig (Merck KGaA, Darmstadt, Germany):\ Benchmarking structure and persistence capabilities for benchmarking results - - Alexander Wieczorek (Swiss Federal Institute for Materials Science and Technology, Dรผbendorf, Switzerland):\ SHAP explainers for insights \ No newline at end of file From 36cb30c076a2b53c437aece8454d2e58265a0141 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 27 Dec 2024 20:42:14 +0100 Subject: [PATCH 23/92] Minor reformatting --- baybe/insights/shap.py | 751 ++++++++++++++++++------------------ tests/insights/test_shap.py | 19 +- 2 files changed, 384 insertions(+), 386 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 4f51d6548..24991b5b9 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -1,373 +1,378 @@ -"""SHAP utilities.""" - -import inspect -import numbers -import warnings -from typing import Any, override - -import numpy as np -import pandas as pd - -from baybe import Campaign -from baybe._optional.insights import shap -from baybe.insights.base import Insight -from baybe.objectives.base import Objective -from baybe.recommenders.pure.bayesian.base import BayesianRecommender -from baybe.searchspace import SearchSpace -from baybe.utils.dataframe import to_tensor - - -class SHAPInsight(Insight): - """Base class for all SHAP insights.""" - - DEFAULT_SHAP_PLOTS = { - "bar", - "scatter", - "heatmap", - "force", - "beeswarm", - } - - @staticmethod - def _get_explainer_maps() -> ( - tuple[dict[str, type[shap.Explainer]], dict[str, type[shap.Explainer]]] - ): - """Get explainer maps for SHAP and non-SHAP explainers. - - Returns: - tuple[dict[str, type[shap.Explainer]], dict[str, type[shap.Explainer]]]: - The explainer maps for SHAP and non-SHAP explainers. 
- """ - EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] - - def _has_required_init_parameters(cls): - """Check if non-shap initializer has required standard parameters.""" - REQUIRED_PARAMETERS = ["self", "model", "data"] - init_signature = inspect.signature(cls.__init__) - parameters = list(init_signature.parameters.keys()) - return parameters[:3] == REQUIRED_PARAMETERS - - shap_explainers = { - cls_name: getattr(shap.explainers, cls_name) - for cls_name in shap.explainers.__all__ - if all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) - } - - non_shap_explainers = { - cls_name: explainer - for cls_name in shap.explainers.other.__all__ - if _has_required_init_parameters( - explainer := getattr(shap.explainers.other, cls_name) - ) - and all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) - } - - return shap_explainers, non_shap_explainers - - SHAP_EXPLAINERS, NON_SHAP_EXPLAINERS = _get_explainer_maps() - - ALL_EXPLAINERS = {**SHAP_EXPLAINERS, **NON_SHAP_EXPLAINERS} - - def __init__( - self, - surrogate_model, - bg_data: pd.DataFrame, - explained_data: pd.DataFrame | None = None, - explainer_class: shap.Explainer | str = "KernelExplainer", - computational_representation: bool = False, - ): - super().__init__(surrogate_model) - self._computational_representation = computational_representation - explainer_cls = ( - explainer_class - if not isinstance(explainer_class, str) - or explainer_class not in self.ALL_EXPLAINERS - else self.ALL_EXPLAINERS[explainer_class] - ) - self._is_shap_explainer = not explainer_cls.__module__.startswith( - "shap.explainers.other." 
- ) - self._bg_data = bg_data - self._explained_data = explained_data - self.explainer = self._init_explainer(bg_data, explainer_cls) # type: ignore[arg-type] - self._explanation = None - - @override - @classmethod - def from_campaign( - cls, - campaign: Campaign, - explained_data: pd.DataFrame | None = None, - explainer_class: shap.Explainer | str = "KernelExplainer", - computational_representation: bool = False, - ): - """Create a SHAP insight from a campaign. - - Args: - campaign: The campaign to be used for the SHAP insight. - explained_data: The data set to be explained. If None, - all measurements from the campaign are used. - explainer_class: The explainer class to be used for the computation. - computational_representation: - Whether to use the computational representation. - - Returns: - SHAPInsight: The SHAP insight object. - - Raises: - ValueError: If the campaign does not contain any measurements. - """ - if campaign.measurements.empty: - raise ValueError("The campaign does not contain any measurements.") - data = campaign.measurements[[p.name for p in campaign.parameters]].copy() - return cls( - campaign.get_surrogate(), - bg_data=campaign.searchspace.transform(data) - if computational_representation - else data, - explainer_class=explainer_class, - computational_representation=computational_representation, - explained_data=explained_data, - ) - - @override - @classmethod - def from_recommender( - cls, - recommender: BayesianRecommender, - searchspace: SearchSpace, - objective: Objective, - bg_data: pd.DataFrame, - explained_data: pd.DataFrame | None = None, - explainer_class: shap.Explainer | str = "KernelExplainer", - computational_representation: bool = False, - ): - """Create a SHAP insight from a recommender. - - Args: - recommender: The recommender to be used for the SHAP insight. - searchspace: The searchspace for the recommender. - objective: The objective for the recommender. - bg_data: The background data set for Explainer. 
- This is also the measurement data set for the recommender. - explained_data: The data set to be explained. If None, - the background data set is used. - explainer_class: The explainer class. - computational_representation: - Whether to use the computational representation. - - Returns: - SHAPInsight: The SHAP insight object. - - Raises: - ValueError: If the recommender has not implemented a "get_surrogate" method. - """ - if not hasattr(recommender, "get_surrogate"): - raise ValueError( - "The provided recommender does not provide a surrogate model." - ) - surrogate_model = recommender.get_surrogate(searchspace, objective, bg_data) - - return cls( - surrogate_model, - bg_data=searchspace.transform(bg_data) - if computational_representation - else bg_data, - explained_data=explained_data, - explainer_class=explainer_class, - computational_representation=computational_representation, - ) - - def _init_explainer( - self, - data: pd.DataFrame, - explainer_class: type[shap.Explainer] = shap.KernelExplainer, - **kwargs, - ) -> shap.Explainer: - """Create an explainer for the provided campaign. - - Args: - data: The background data set. - explainer_class: The explainer class to be used. - **kwargs: Additional keyword arguments to be passed to the explainer. - - Returns: - shap.Explainer: The created explainer object. - - Raises: - NotImplementedError: If the provided explainer class does - not support the experimental representation. - ValueError: If the provided background data set is empty. - TypeError: If the provided explainer class does not - support the campaign surrogate. - """ - if not self._is_shap_explainer and not self._computational_representation: - raise NotImplementedError( - "Experimental representation is not " - "supported for non-Kernel SHAP explainer." 
- ) - - if data.empty: - raise ValueError("The provided background data set is empty.") - - if self._computational_representation: - - def model(x): - tensor = to_tensor(x) - output = self.surrogate._posterior_comp(tensor).mean - - return output.detach().numpy() - else: - - def model(x): - df = pd.DataFrame(x, columns=data.columns) - output = self.surrogate.posterior(df).mean - - return output.detach().numpy() - - try: - shap_explainer = explainer_class(model, data, **kwargs) - """Explain first two data points to ensure that the explainer is working.""" - if self._is_shap_explainer: - shap_explainer(self._bg_data.iloc[0:1]) - except shap.utils._exceptions.InvalidModelError: - raise TypeError( - "The selected explainer class does not support the campaign surrogate." - ) - except TypeError as e: - if ( - "not supported for the input types" in str(e) - and not self._computational_representation - ): - raise NotImplementedError( - "The selected explainer class does not support experimental " - "representation. Switch to computational representation " - "or use a different explainer " - "(e.g. the default shap.KernelExplainer)." - ) - return shap_explainer - - def _init_explanation( - self, - data: pd.DataFrame | None = None, - ) -> shap.Explanation: - """Compute the Shapley values based on the chosen explainer and data set. - - Args: - data: The data set for which the Shapley values should be computed. - - Returns: - shap.Explanation: The computed Shapley values. - - Raises: - ValueError: If the provided data set does not have the same amount of - parameters as the SHAP explainer background - """ - if data is None: - data = self._bg_data - elif not self._bg_data.shape[1] == data.shape[1]: - raise ValueError( - "The provided data does not have the same amount of " - "parameters as the shap explainer background." 
- ) - - # Type checking for mypy - assert isinstance(data, pd.DataFrame) - - if not self._is_shap_explainer: - """Return attributions for non-SHAP explainers.""" - if self.explainer.__module__.endswith("maple"): - """Additional argument for maple to increase comparability to SHAP.""" - attributions = self.explainer.attributions( - data, multiply_by_input=True - )[0] - else: - attributions = self.explainer.attributions(data)[0] - explanations = shap.Explanation( - values=attributions, - base_values=self.explainer.model(self._bg_data).mean(), - data=data, - feature_names=data.columns.values, - ) - return explanations - else: - explanations = self.explainer(data) - - """Ensure that the explanation object is of the correct dimensionality.""" - if len(explanations.shape) == 2: - return explanations - if len(explanations.shape) == 3: - return explanations[:, :, 0] - raise ValueError( - "The Explanation has an invalid " - f"dimensionality of {len(explanations.shape)}." - ) - - @property - def explanation(self) -> shap.Explanation: - """Get the SHAP explanation object. Uses lazy evaluation. - - Returns: - shap.Explanation: The SHAP explanation object. - """ - if self._explanation is None: - self._explanation = self._init_explanation() - - return self._explanation - - def plot(self, plot_type: str, **kwargs) -> None: - """Plot the Shapley values using the provided plot type. - - Args: - plot_type: The type of plot to be created. Supported types are: - "bar", "scatter", "heatmap", "force", "beeswarm". - **kwargs: Additional keyword arguments to be passed to the plot function. 
- - Raises: - ValueError: If the provided plot type is not supported - """ - if plot_type == "scatter": - self._plot_shap_scatter(**kwargs) - return None - - plot = getattr(shap.plots, plot_type, None) - if ( - plot is None - or not callable(plot) - or plot_type not in self.DEFAULT_SHAP_PLOTS - ): - raise ValueError(f"Invalid plot type: {plot_type}") - - plot(self.explanation, **kwargs) - - def _plot_shap_scatter(self, **kwargs: Any) -> None: - """Plot the Shapley values as scatter plot while leaving out string values. - - Args: - **kwargs: Additional keyword arguments to be passed to the plot function. - """ - - def is_not_numeric_column(col): - return np.array([not isinstance(v, numbers.Number) for v in col]).any() - - if np.ndim(self._bg_data) == 1: - if is_not_numeric_column(self._bg_data): - warnings.warn( - "Cannot plot scatter plot for the provided " - "explanation as it contains non-numeric values." - ) - else: - shap.plots.scatter(self.explanation) - else: - # Type checking for mypy - assert isinstance(self._bg_data, pd.DataFrame) - - mask = self._bg_data.iloc[0].apply(lambda x: not isinstance(x, str)) - number_enum = np.where(mask)[0].tolist() - - if len(number_enum) < len(self._bg_data.iloc[0]): - warnings.warn( - "Cannot plot SHAP scatter plot for all " - "parameters as some contain non-numeric values." 
- ) - shap.plots.scatter(self.explanation[:, number_enum], **kwargs) +"""SHAP insights.""" + +import inspect +import numbers +import warnings +from typing import Any + +import numpy as np +import pandas as pd +from typing_extensions import override + +from baybe import Campaign +from baybe._optional.insights import shap +from baybe.insights.base import Insight +from baybe.objectives.base import Objective +from baybe.recommenders.pure.bayesian.base import BayesianRecommender +from baybe.searchspace import SearchSpace +from baybe.utils.dataframe import to_tensor + + +class SHAPInsight(Insight): + """Base class for all SHAP insights.""" + + DEFAULT_SHAP_PLOTS = { + "bar", + "scatter", + "heatmap", + "force", + "beeswarm", + } + + @staticmethod + def _get_explainer_maps() -> ( + tuple[dict[str, type[shap.Explainer]], dict[str, type[shap.Explainer]]] + ): + """Get explainer maps for SHAP and non-SHAP explainers. + + Returns: + The explainer maps for SHAP and non-SHAP explainers. + """ + EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] + + def _has_required_init_parameters(cls): + """Check if non-shap initializer has required standard parameters.""" + REQUIRED_PARAMETERS = ["self", "model", "data"] + init_signature = inspect.signature(cls.__init__) + parameters = list(init_signature.parameters.keys()) + return parameters[:3] == REQUIRED_PARAMETERS + + shap_explainers = { + cls_name: getattr(shap.explainers, cls_name) + for cls_name in shap.explainers.__all__ + if all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) + } + + non_shap_explainers = { + cls_name: explainer + for cls_name in shap.explainers.other.__all__ + if _has_required_init_parameters( + explainer := getattr(shap.explainers.other, cls_name) + ) + and all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) + } + + return shap_explainers, non_shap_explainers + + SHAP_EXPLAINERS, NON_SHAP_EXPLAINERS = _get_explainer_maps() + + ALL_EXPLAINERS = {**SHAP_EXPLAINERS, 
**NON_SHAP_EXPLAINERS} + + def __init__( + self, + surrogate_model, + bg_data: pd.DataFrame, + explained_data: pd.DataFrame | None = None, + explainer_class: shap.Explainer | str = "KernelExplainer", + use_comp_rep: bool = False, + ): + super().__init__(surrogate_model) + self._use_comp_rep = use_comp_rep + explainer_cls = ( + explainer_class + if not isinstance(explainer_class, str) + or explainer_class not in self.ALL_EXPLAINERS + else self.ALL_EXPLAINERS[explainer_class] + ) + self._is_shap_explainer = not explainer_cls.__module__.startswith( + "shap.explainers.other." + ) + self._bg_data = bg_data + self._explained_data = explained_data + self.explainer = self._init_explainer(bg_data, explainer_cls) # type: ignore[arg-type] + self._explanation = None + + @override + @classmethod + def from_campaign( + cls, + campaign: Campaign, + explained_data: pd.DataFrame | None = None, + explainer_class: shap.Explainer | str = "KernelExplainer", + use_comp_rep: bool = False, + ): + """Create a SHAP insight from a campaign. + + Args: + campaign: The campaign to be used for the SHAP insight. + explained_data: The data set to be explained. If None, + all measurements from the campaign are used. + explainer_class: The explainer class to be used for the computation. + use_comp_rep: + Whether to analyze the model in computational representation + (experimental representation otherwise). + + Returns: + The SHAP insight object. + + Raises: + ValueError: If the campaign does not contain any measurements. + """ + if campaign.measurements.empty: + raise ValueError( + f"The campaign does not contain any measurements. A {cls.__name__} " + f"assumes there is mandatory background data in the form of " + f"measurements as part of the campaign." 
+ ) + data = campaign.measurements[[p.name for p in campaign.parameters]].copy() + + return cls( + campaign.get_surrogate(), + bg_data=campaign.searchspace.transform(data) if use_comp_rep else data, + explainer_class=explainer_class, + use_comp_rep=use_comp_rep, + explained_data=explained_data, + ) + + @override + @classmethod + def from_recommender( + cls, + recommender: BayesianRecommender, + searchspace: SearchSpace, + objective: Objective, + bg_data: pd.DataFrame, + explained_data: pd.DataFrame | None = None, + explainer_class: shap.Explainer | str = "KernelExplainer", + use_comp_rep: bool = False, + ): + """Create a SHAP insight from a recommender. + + Args: + recommender: The recommender to be used for the SHAP insight. + searchspace: The searchspace for the recommender. + objective: The objective for the recommender. + bg_data: The background data set for Explainer. + This is also the measurement data set for the recommender. + explained_data: The data set to be explained. If None, + the background data set is used. + explainer_class: The explainer class. + use_comp_rep: + Whether to analyze the model in computational representation + (experimental representation otherwise). + + Returns: + The SHAP insight object. + + Raises: + ValueError: If the recommender has not implemented a "get_surrogate" method. + """ + if not hasattr(recommender, "get_surrogate"): + raise ValueError( + "The provided recommender does not provide a surrogate model." + ) + surrogate_model = recommender.get_surrogate(searchspace, objective, bg_data) + + return cls( + surrogate_model, + bg_data=searchspace.transform(bg_data) if use_comp_rep else bg_data, + explained_data=explained_data, + explainer_class=explainer_class, + use_comp_rep=use_comp_rep, + ) + + def _init_explainer( + self, + bg_data: pd.DataFrame, + explainer_class: type[shap.Explainer] = shap.KernelExplainer, + **kwargs, + ) -> shap.Explainer: + """Create a SHAP explainer. + + Args: + bg_data: The background data set. 
+ explainer_class: The explainer class to be used. + **kwargs: Additional keyword arguments to be passed to the explainer. + + Returns: + shap.Explainer: The created explainer object. + + Raises: + NotImplementedError: If the provided explainer class does + not support the experimental representation. + ValueError: If the provided background data set is empty. + TypeError: If the provided explainer class does not + support the campaign surrogate. + """ + if not self._is_shap_explainer and not self._use_comp_rep: + raise NotImplementedError( + "Experimental representation is not " + "supported for non-Kernel SHAP explainer." + ) + + if bg_data.empty: + raise ValueError("The provided background data set is empty.") + + if self._use_comp_rep: + + def model(x): + tensor = to_tensor(x) + output = self.surrogate._posterior_comp(tensor).mean + + return output.detach().numpy() + else: + + def model(x): + df = pd.DataFrame(x, columns=bg_data.columns) + output = self.surrogate.posterior(df).mean + + return output.detach().numpy() + + try: + shap_explainer = explainer_class(model, bg_data, **kwargs) + """Explain first two data points to ensure that the explainer is working.""" + if self._is_shap_explainer: + shap_explainer(self._bg_data.iloc[0:1]) + except shap.utils._exceptions.InvalidModelError: + raise TypeError( + f"The selected explainer class {explainer_class} does not support the " + f"provided surrogate model." + ) + except TypeError as e: + if "not supported for the input types" in str(e) and not self._use_comp_rep: + raise NotImplementedError( + f"The selected explainer class {explainer_class} does not support " + f"the experimental representation. Switch to computational " + f"representation or use a different explainer (e.g. the default " + f"shap.KernelExplainer)." 
+ ) + else: + raise e + return shap_explainer + + def _init_explanation( + self, + explained_data: pd.DataFrame | None = None, + ) -> shap.Explanation: + """Compute the Shapley values based on the chosen explainer and data set. + + Args: + explained_data: The data set for which the Shapley values should be + computed. + + Returns: + shap.Explanation: The computed Shapley values. + + Raises: + ValueError: If the provided data set does not have the same amount of + parameters as the SHAP explainer background + """ + if explained_data is None: + explained_data = self._bg_data + elif not self._bg_data.shape[1] == explained_data.shape[1]: + raise ValueError( + "The provided data does not have the same amount of " + "parameters as the shap explainer background." + ) + + # Type checking for mypy + assert isinstance(explained_data, pd.DataFrame) + + if not self._is_shap_explainer: + # Return attributions for non-SHAP explainers + if self.explainer.__module__.endswith("maple"): + # Additional argument for maple to increase comparability to SHAP + attributions = self.explainer.attributions( + explained_data, multiply_by_input=True + )[0] + else: + attributions = self.explainer.attributions(explained_data)[0] + + explanations = shap.Explanation( + values=attributions, + base_values=self.explainer.model(self._bg_data).mean(), + data=explained_data, + feature_names=explained_data.columns.values, + ) + return explanations + else: + explanations = self.explainer(explained_data) + + """Ensure that the explanation object is of the correct dimensionality.""" + if len(explanations.shape) == 2: + return explanations + if len(explanations.shape) == 3: + return explanations[:, :, 0] + raise ValueError( + "The Explanation has an invalid " + f"dimensionality of {len(explanations.shape)}." + ) + + @property + def explanation(self) -> shap.Explanation: + """Get the SHAP explanation object. Uses lazy evaluation. + + Returns: + shap.Explanation: The SHAP explanation object. 
+ """ + if self._explanation is None: + self._explanation = self._init_explanation() + + return self._explanation + + def plot(self, plot_type: str, **kwargs) -> None: + """Plot the Shapley values using the provided plot type. + + Args: + plot_type: The type of plot to be created. Supported types are: + "bar", "scatter", "heatmap", "force", "beeswarm". + **kwargs: Additional keyword arguments to be passed to the plot function. + + Raises: + ValueError: If the provided plot type is not supported + """ + if plot_type == "scatter": + self._plot_shap_scatter(**kwargs) + return None + + plot = getattr(shap.plots, plot_type, None) + if ( + plot is None + or not callable(plot) + or plot_type not in self.DEFAULT_SHAP_PLOTS + ): + raise ValueError(f"Invalid plot type: {plot_type}") + + plot(self.explanation, **kwargs) + + def _plot_shap_scatter(self, **kwargs: Any) -> None: + """Plot the Shapley values as scatter plot while leaving out string values. + + Args: + **kwargs: Additional keyword arguments to be passed to the plot function. + """ + + def is_not_numeric_column(col): + return np.array([not isinstance(v, numbers.Number) for v in col]).any() + + if np.ndim(self._bg_data) == 1: + if is_not_numeric_column(self._bg_data): + warnings.warn( + "Cannot plot scatter plot for the provided " + "explanation as it contains non-numeric values." + ) + else: + shap.plots.scatter(self.explanation) + else: + # Type checking for mypy + assert isinstance(self._bg_data, pd.DataFrame) + + mask = self._bg_data.iloc[0].apply(lambda x: not isinstance(x, str)) + number_enum = np.where(mask)[0].tolist() + + if len(number_enum) < len(self._bg_data.iloc[0]): + warnings.warn( + "Cannot plot SHAP scatter plot for all " + "parameters as some contain non-numeric values." 
+ ) + shap.plots.scatter(self.explanation[:, number_enum], **kwargs) diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index d3810056d..aab6303c8 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -45,7 +45,7 @@ def _test_shap_insights(campaign, explainer_cls, use_comp_rep, is_shap): shap_insights = SHAPInsight.from_campaign( campaign, explainer_class=explainer_cls, - computational_representation=use_comp_rep, + use_comp_rep=use_comp_rep, ) assert isinstance(shap_insights, insights.SHAPInsight) assert isinstance( @@ -63,24 +63,17 @@ def _test_shap_insights(campaign, explainer_cls, use_comp_rep, is_shap): ): shap_insights._init_explanation(df) except TypeError as e: - if ( - "The selected explainer class does not support the campaign surrogate." - in str(e) - ): - pass + if "The selected explainer class" in str(e): + pytest.xfail("Unsupported model/explainer combination") else: raise e except NotImplementedError as e: if ( - "The selected explainer class does not support experimental " - "representation. Switch to computational representation or " - "use a different explainer (e.g. the default " - "shap.KernelExplainer)." - in str(e) + "The selected explainer class" in str(e) and not use_comp_rep and not isinstance(explainer_cls, shap.explainers.KernelExplainer) ): - pass + pytest.xfail("Exp. rep. 
not supported") else: raise e @@ -131,7 +124,7 @@ def test_shap_insight_plots(campaign, use_comp_rep, plot_type): run_iterations(campaign, n_iterations=2, batch_size=1) shap_insights = SHAPInsight.from_campaign( campaign, - computational_representation=use_comp_rep, + use_comp_rep=use_comp_rep, ) with mock.patch("matplotlib.pyplot.show"): shap_insights.plot(plot_type) From 67fdf8d67fbe9113b298a5f96587c002a692d582 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 27 Dec 2024 21:07:00 +0100 Subject: [PATCH 24/92] Package housekeeping --- CHANGELOG.md | 4 ++-- README.md | 4 ++-- baybe/_optional/insights.py | 2 +- baybe/insights/__init__.py | 2 -- pyproject.toml | 3 +-- 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6d19f210..6977a6b83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added -- `insights` dependency group -- SHAP explanations +- Optional `insights` dependency group +- SHAP explanations via the new `SHAPInsight` class - `allow_missing` and `allow_extra` keyword arguments to `Objective.transform` - Example for a traditional mixture - `add_noise_to_perturb_degenerate_rows` utility diff --git a/README.md b/README.md index 70aef9cdb..ec51e55f9 100644 --- a/README.md +++ b/README.md @@ -296,8 +296,8 @@ The available groups are: - `lint`: Required for linting and formatting. - `mypy`: Required for static type checking. - `onnx`: Required for using custom surrogate models in [ONNX format](https://onnx.ai). -- `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/) -- `insights`: Required for built-in model and campaign analysis, e.g. [SHAP](https://shap.readthedocs.io/)pip install uv +- `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/). +- `insights`: Required for built-in model and campaign analysis, e.g. 
[SHAP](https://shap.readthedocs.io/). - `simulation`: Enabling the [simulation](https://emdgroup.github.io/baybe/stable/_autosummary/baybe.simulation.html) module. - `test`: Required for running the tests. - `benchmarking`: Required for running the benchmarking module. diff --git a/baybe/_optional/insights.py b/baybe/_optional/insights.py index 0f83786bf..9a9505789 100644 --- a/baybe/_optional/insights.py +++ b/baybe/_optional/insights.py @@ -1,4 +1,4 @@ -"""Optional import for insight subpackage.""" +"""Optional import for the insights subpackage.""" from baybe.exceptions import OptionalImportError diff --git a/baybe/insights/__init__.py b/baybe/insights/__init__.py index 0cc340c51..767d0288a 100644 --- a/baybe/insights/__init__.py +++ b/baybe/insights/__init__.py @@ -3,10 +3,8 @@ from baybe._optional.info import INSIGHTS_INSTALLED if INSIGHTS_INSTALLED: - from baybe.insights.base import Insight from baybe.insights.shap import SHAPInsight __all__ = [ "SHAPInsight", - "Insight", ] diff --git a/pyproject.toml b/pyproject.toml index 0851ad239..baa03143c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,16 +79,15 @@ onnx = [ dev = [ "baybe[chem]", - "baybe[insights]", "baybe[docs]", "baybe[examples]", + "baybe[insights]", "baybe[lint]", "baybe[mypy]", "baybe[onnx]", "baybe[polars]", "baybe[simulation]", "baybe[test]", - "baybe[insights]", "baybe[benchmarking]", "pip-audit>=2.5.5", "tox-uv>=1.7.0", From 69cecd83002b1fb0aea9dfba8e48e5151cf73ccb Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Thu, 2 Jan 2025 17:53:19 +0100 Subject: [PATCH 25/92] Rework classes --- baybe/insights/base.py | 54 ++++-- baybe/insights/shap.py | 276 ++++++++++++++++-------------- tests/insights/test_shap.py | 324 ++++++++++++++++++------------------ 3 files changed, 347 insertions(+), 307 deletions(-) diff --git a/baybe/insights/base.py b/baybe/insights/base.py index 60a7e81ca..513105d87 100644 --- a/baybe/insights/base.py +++ b/baybe/insights/base.py @@ -1,28 +1,36 @@ """Base class 
for all insights.""" +from __future__ import annotations + from abc import ABC import pandas as pd +from attrs import define, field from baybe import Campaign -from baybe._optional.info import INSIGHTS_INSTALLED from baybe.objectives.base import Objective from baybe.recommenders.pure.bayesian.base import BayesianRecommender from baybe.searchspace import SearchSpace - -if INSIGHTS_INSTALLED: - pass +from baybe.surrogates.base import SurrogateProtocol +@define class Insight(ABC): """Base class for all insights.""" - def __init__(self, surrogate): - self.surrogate = surrogate + surrogate: SurrogateProtocol = field() + """The surrogate model that is supposed bo be analyzed.""" @classmethod - def from_campaign(cls, campaign: Campaign): - """Create an insight from a campaign.""" + def from_campaign(cls, campaign: Campaign) -> Insight: + """Create an insight from a campaign. + + Args: + campaign: A baybe Campaign object. + + Returns: + The Insight object. + """ return cls(campaign.get_surrogate()) @classmethod @@ -31,15 +39,29 @@ def from_recommender( recommender: BayesianRecommender, searchspace: SearchSpace, objective: Objective, - bg_data: pd.DataFrame, - ): - """Create an insight from a recommender.""" + measurements: pd.DataFrame, + ) -> Insight: + """Create an insight from a recommender. + + Args: + recommender: A model-based recommender. + searchspace: The search space used for recommendations. + objective: The objective of the recommendation. + measurements: The measurements in experimental representation. + + Returns: + The Insight object. + + Raises: + ValueError: If the provided recommender is not surrogate-based. + """ if not hasattr(recommender, "get_surrogate"): raise ValueError( - "The provided recommender does not provide a surrogate model." + f"The provided recommender of type '{recommender.__class__.__name__}' " + f"does not provide a surrogate model." 
) - surrogate_model = recommender.get_surrogate(searchspace, objective, bg_data) - - return cls( - surrogate_model, + surrogate_model = recommender.get_surrogate( + searchspace, objective, measurements ) + + return cls(surrogate_model) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 24991b5b9..50516a407 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -3,10 +3,10 @@ import inspect import numbers import warnings -from typing import Any import numpy as np import pandas as pd +from attrs import define, field from typing_extensions import override from baybe import Campaign @@ -18,79 +18,94 @@ from baybe.utils.dataframe import to_tensor -class SHAPInsight(Insight): - """Base class for all SHAP insights.""" - - DEFAULT_SHAP_PLOTS = { - "bar", - "scatter", - "heatmap", - "force", - "beeswarm", - } +def _get_explainer_maps() -> ( + tuple[dict[str, type[shap.Explainer]], dict[str, type[shap.Explainer]]] +): + """Get maps for SHAP and non-SHAP explainers. - @staticmethod - def _get_explainer_maps() -> ( - tuple[dict[str, type[shap.Explainer]], dict[str, type[shap.Explainer]]] - ): - """Get explainer maps for SHAP and non-SHAP explainers. + Returns: + The maps for SHAP and non-SHAP explainers. + """ + EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] - Returns: - The explainer maps for SHAP and non-SHAP explainers. 
- """ - EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] - - def _has_required_init_parameters(cls): - """Check if non-shap initializer has required standard parameters.""" - REQUIRED_PARAMETERS = ["self", "model", "data"] - init_signature = inspect.signature(cls.__init__) - parameters = list(init_signature.parameters.keys()) - return parameters[:3] == REQUIRED_PARAMETERS - - shap_explainers = { - cls_name: getattr(shap.explainers, cls_name) - for cls_name in shap.explainers.__all__ - if all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) - } - - non_shap_explainers = { - cls_name: explainer - for cls_name in shap.explainers.other.__all__ - if _has_required_init_parameters( - explainer := getattr(shap.explainers.other, cls_name) - ) - and all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) - } + def _has_required_init_parameters(cls): + """Check if non-shap initializer has required standard parameters.""" + REQUIRED_PARAMETERS = ["self", "model", "data"] - return shap_explainers, non_shap_explainers + init_signature = inspect.signature(cls.__init__) + parameters = list(init_signature.parameters.keys()) - SHAP_EXPLAINERS, NON_SHAP_EXPLAINERS = _get_explainer_maps() + return parameters[:3] == REQUIRED_PARAMETERS - ALL_EXPLAINERS = {**SHAP_EXPLAINERS, **NON_SHAP_EXPLAINERS} + shap_explainers = { + cls_name: getattr(shap.explainers, cls_name) + for cls_name in shap.explainers.__all__ + if all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) + } - def __init__( - self, - surrogate_model, - bg_data: pd.DataFrame, - explained_data: pd.DataFrame | None = None, - explainer_class: shap.Explainer | str = "KernelExplainer", - use_comp_rep: bool = False, - ): - super().__init__(surrogate_model) - self._use_comp_rep = use_comp_rep - explainer_cls = ( - explainer_class - if not isinstance(explainer_class, str) - or explainer_class not in self.ALL_EXPLAINERS - else self.ALL_EXPLAINERS[explainer_class] - ) - 
self._is_shap_explainer = not explainer_cls.__module__.startswith( - "shap.explainers.other." + non_shap_explainers = { + cls_name: explainer + for cls_name in shap.explainers.other.__all__ + if _has_required_init_parameters( + explainer := getattr(shap.explainers.other, cls_name) ) - self._bg_data = bg_data - self._explained_data = explained_data - self.explainer = self._init_explainer(bg_data, explainer_cls) # type: ignore[arg-type] - self._explanation = None + and all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) + } + + return shap_explainers, non_shap_explainers + + +SHAP_EXPLAINERS, NON_SHAP_EXPLAINERS = _get_explainer_maps() +ALL_EXPLAINERS = SHAP_EXPLAINERS | NON_SHAP_EXPLAINERS +SUPPORTED_SHAP_PLOTS = { + "bar", + "scatter", + "heatmap", + "force", + "beeswarm", +} + + +@define +class SHAPInsight(Insight): + """Class for SHAP-based feature importance insights. + + This also supports LIME and MAPLE explainers via ways provided by the shap module. + """ + + bg_data: pd.DataFrame = field() + """The background data set used to build the explainer.""" + + explained_data: pd.DataFrame | None = field(default=None) + """The data for which a SHAP explanation is generated.""" + + explainer_cls: type[shap.Explainer] | str = field( + default="KernelExplainer", + converter=lambda x: ALL_EXPLAINERS[x] if isinstance(x, str) else x, + ) + """The SHAP explainer class that is used to generate the explanation. + + Some non-SHAP explainers, like MAPLE and LIME, are also supported if they are + available via 'shap.explainers.other'. 
+ """ + + use_comp_rep: bool = field(default=False) + """Flag for toggling in which representation the insight should be provided.""" + + _explainer: shap.Explainer | None = field(default=None, init=False) + """The explainer generated from the model and background data.""" + + _explanation: shap.Explanation | None = field(default=None, init=False) + """The explanation generated.""" + + def __attrs_post_init__(self): + """Initialize the explainer.""" + self._explainer = self._init_explainer(self.bg_data, self.explainer_cls) + + @property + def uses_shap_explainer(self): + """Whether the explainer is a SHAP explainer or not (e.g. MAPLE, LIME).""" + return not self.explainer_cls.__module__.startswith("shap.explainers.other.") @override @classmethod @@ -98,16 +113,17 @@ def from_campaign( cls, campaign: Campaign, explained_data: pd.DataFrame | None = None, - explainer_class: shap.Explainer | str = "KernelExplainer", + explainer_cls: type[shap.Explainer] | str = "KernelExplainer", use_comp_rep: bool = False, ): """Create a SHAP insight from a campaign. Args: - campaign: The campaign to be used for the SHAP insight. - explained_data: The data set to be explained. If None, - all measurements from the campaign are used. - explainer_class: The explainer class to be used for the computation. + campaign: The campaign which holds the recommender and model. + explained_data: The data set to be explained. If None, all measurements + from the campaign are used. + explainer_cls: The SHAP explainer class that is used to generate the + explanation. use_comp_rep: Whether to analyze the model in computational representation (experimental representation otherwise). @@ -120,7 +136,7 @@ def from_campaign( """ if campaign.measurements.empty: raise ValueError( - f"The campaign does not contain any measurements. A {cls.__name__} " + f"The campaign does not contain any measurements. 
A '{cls.__name__}' " f"assumes there is mandatory background data in the form of " f"measurements as part of the campaign." ) @@ -129,7 +145,7 @@ def from_campaign( return cls( campaign.get_surrogate(), bg_data=campaign.searchspace.transform(data) if use_comp_rep else data, - explainer_class=explainer_class, + explainer_cls=explainer_cls, use_comp_rep=use_comp_rep, explained_data=explained_data, ) @@ -141,22 +157,22 @@ def from_recommender( recommender: BayesianRecommender, searchspace: SearchSpace, objective: Objective, - bg_data: pd.DataFrame, + measurements: pd.DataFrame, explained_data: pd.DataFrame | None = None, - explainer_class: shap.Explainer | str = "KernelExplainer", + explainer_cls: type[shap.Explainer] | str = "KernelExplainer", use_comp_rep: bool = False, ): """Create a SHAP insight from a recommender. Args: - recommender: The recommender to be used for the SHAP insight. + recommender: The model-based recommender. searchspace: The searchspace for the recommender. objective: The objective for the recommender. - bg_data: The background data set for Explainer. - This is also the measurement data set for the recommender. + measurements: The background data set for Explainer. + This is used the measurement data set for the recommender. explained_data: The data set to be explained. If None, the background data set is used. - explainer_class: The explainer class. + explainer_cls: The explainer class. use_comp_rep: Whether to analyze the model in computational representation (experimental representation otherwise). @@ -169,29 +185,36 @@ def from_recommender( """ if not hasattr(recommender, "get_surrogate"): raise ValueError( - "The provided recommender does not provide a surrogate model." + f"The provided recommender does not provide a surrogate model. A " + f"'{cls.__name__}' needs a surrogate model and thus only works with " + f"model-based recommenders." 
) - surrogate_model = recommender.get_surrogate(searchspace, objective, bg_data) + surrogate_model = recommender.get_surrogate( + searchspace, objective, measurements + ) return cls( surrogate_model, - bg_data=searchspace.transform(bg_data) if use_comp_rep else bg_data, + bg_data=searchspace.transform(measurements) + if use_comp_rep + else measurements, explained_data=explained_data, - explainer_class=explainer_class, + explainer_cls=explainer_cls, use_comp_rep=use_comp_rep, ) def _init_explainer( self, bg_data: pd.DataFrame, - explainer_class: type[shap.Explainer] = shap.KernelExplainer, + explainer_cls: type[shap.Explainer] = shap.KernelExplainer, **kwargs, ) -> shap.Explainer: """Create a SHAP explainer. Args: bg_data: The background data set. - explainer_class: The explainer class to be used. + explainer_cls: The SHAP explainer class that is used to generate the + explanation. **kwargs: Additional keyword arguments to be passed to the explainer. Returns: @@ -204,16 +227,16 @@ def _init_explainer( TypeError: If the provided explainer class does not support the campaign surrogate. """ - if not self._is_shap_explainer and not self._use_comp_rep: + if not self.uses_shap_explainer and not self.use_comp_rep: raise NotImplementedError( - "Experimental representation is not " - "supported for non-Kernel SHAP explainer." + "Experimental representation is not supported for non-Kernel SHAP " + "explainer." 
) if bg_data.empty: raise ValueError("The provided background data set is empty.") - if self._use_comp_rep: + if self.use_comp_rep: def model(x): tensor = to_tensor(x) @@ -229,19 +252,20 @@ def model(x): return output.detach().numpy() try: - shap_explainer = explainer_class(model, bg_data, **kwargs) - """Explain first two data points to ensure that the explainer is working.""" - if self._is_shap_explainer: - shap_explainer(self._bg_data.iloc[0:1]) + shap_explainer = explainer_cls(model, bg_data, **kwargs) + + # Explain first two data points to ensure that the explainer is working + if self.uses_shap_explainer: + shap_explainer(self.bg_data.iloc[0:1]) except shap.utils._exceptions.InvalidModelError: raise TypeError( - f"The selected explainer class {explainer_class} does not support the " + f"The selected explainer class {explainer_cls} does not support the " f"provided surrogate model." ) except TypeError as e: - if "not supported for the input types" in str(e) and not self._use_comp_rep: + if "not supported for the input types" in str(e) and not self.use_comp_rep: raise NotImplementedError( - f"The selected explainer class {explainer_class} does not support " + f"The selected explainer class {explainer_cls} does not support " f"the experimental representation. Switch to computational " f"representation or use a different explainer (e.g. the default " f"shap.KernelExplainer)." @@ -268,59 +292,55 @@ def _init_explanation( parameters as the SHAP explainer background """ if explained_data is None: - explained_data = self._bg_data - elif not self._bg_data.shape[1] == explained_data.shape[1]: + explained_data = self.bg_data + elif not self.bg_data.shape[1] == explained_data.shape[1]: raise ValueError( "The provided data does not have the same amount of " "parameters as the shap explainer background." 
) # Type checking for mypy - assert isinstance(explained_data, pd.DataFrame) + assert self._explainer is not None - if not self._is_shap_explainer: + if not self.uses_shap_explainer: # Return attributions for non-SHAP explainers - if self.explainer.__module__.endswith("maple"): + if self._explainer.__module__.endswith("maple"): # Additional argument for maple to increase comparability to SHAP - attributions = self.explainer.attributions( + attributions = self._explainer.attributions( explained_data, multiply_by_input=True )[0] else: - attributions = self.explainer.attributions(explained_data)[0] + attributions = self._explainer.attributions(explained_data)[0] explanations = shap.Explanation( values=attributions, - base_values=self.explainer.model(self._bg_data).mean(), + base_values=self._explainer.model(self.bg_data).mean(), data=explained_data, feature_names=explained_data.columns.values, ) return explanations else: - explanations = self.explainer(explained_data) + explanations = self._explainer(explained_data) """Ensure that the explanation object is of the correct dimensionality.""" if len(explanations.shape) == 2: return explanations if len(explanations.shape) == 3: return explanations[:, :, 0] - raise ValueError( - "The Explanation has an invalid " - f"dimensionality of {len(explanations.shape)}." + raise RuntimeError( + f"The explanation obtained for {self.__class__.__name__} has an unexpected " + f"invalid dimensionality of {len(explanations.shape)}." ) @property def explanation(self) -> shap.Explanation: - """Get the SHAP explanation object. Uses lazy evaluation. - - Returns: - shap.Explanation: The SHAP explanation object. - """ + """Get the SHAP explanation object. Uses lazy evaluation.""" if self._explanation is None: self._explanation = self._init_explanation() return self._explanation - def plot(self, plot_type: str, **kwargs) -> None: + def plot(self, plot_type: str, **kwargs: dict) -> None: """Plot the Shapley values using the provided plot type. 
Args: @@ -329,7 +349,7 @@ def plot(self, plot_type: str, **kwargs) -> None: **kwargs: Additional keyword arguments to be passed to the plot function. Raises: - ValueError: If the provided plot type is not supported + ValueError: If the provided plot type is not supported. """ if plot_type == "scatter": self._plot_shap_scatter(**kwargs) @@ -337,15 +357,15 @@ def plot(self, plot_type: str, **kwargs) -> None: plot = getattr(shap.plots, plot_type, None) if ( - plot is None - or not callable(plot) - or plot_type not in self.DEFAULT_SHAP_PLOTS + (plot_type not in SUPPORTED_SHAP_PLOTS) + or (plot is None) + or (not callable(plot)) ): - raise ValueError(f"Invalid plot type: {plot_type}") + raise ValueError(f"Invalid plot type: '{plot_type}'.") plot(self.explanation, **kwargs) - def _plot_shap_scatter(self, **kwargs: Any) -> None: + def _plot_shap_scatter(self, **kwargs: dict) -> None: """Plot the Shapley values as scatter plot while leaving out string values. Args: @@ -355,8 +375,8 @@ def _plot_shap_scatter(self, **kwargs: Any) -> None: def is_not_numeric_column(col): return np.array([not isinstance(v, numbers.Number) for v in col]).any() - if np.ndim(self._bg_data) == 1: - if is_not_numeric_column(self._bg_data): + if np.ndim(self.bg_data) == 1: + if is_not_numeric_column(self.bg_data): warnings.warn( "Cannot plot scatter plot for the provided " "explanation as it contains non-numeric values." 
@@ -365,14 +385,14 @@ def is_not_numeric_column(col): shap.plots.scatter(self.explanation) else: # Type checking for mypy - assert isinstance(self._bg_data, pd.DataFrame) + assert isinstance(self.bg_data, pd.DataFrame) - mask = self._bg_data.iloc[0].apply(lambda x: not isinstance(x, str)) + mask = self.bg_data.iloc[0].apply(lambda x: not isinstance(x, str)) number_enum = np.where(mask)[0].tolist() - if len(number_enum) < len(self._bg_data.iloc[0]): + if len(number_enum) < len(self.bg_data.iloc[0]): warnings.warn( - "Cannot plot SHAP scatter plot for all " - "parameters as some contain non-numeric values." + "Cannot plot SHAP scatter plot for all parameters as some contain " + "non-numeric values." ) shap.plots.scatter(self.explanation[:, number_enum], **kwargs) diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index aab6303c8..0b2122989 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -1,163 +1,161 @@ -"""Tests for insights subpackage.""" - -from unittest import mock - -import pandas as pd -import pytest -from pytest import param - -from baybe._optional.info import INSIGHTS_INSTALLED -from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender -from baybe.recommenders.pure.bayesian.base import BayesianRecommender -from baybe.searchspace import SearchSpaceType -from baybe.utils.basic import get_subclasses -from tests.conftest import run_iterations - -pytestmark = pytest.mark.skipif( - not INSIGHTS_INSTALLED, reason="Optional insights dependency not installed." 
-) - -if INSIGHTS_INSTALLED: - import shap - - from baybe import insights - from baybe.insights.shap import SHAPInsight - - default_shap_plots = SHAPInsight.DEFAULT_SHAP_PLOTS - shap_explainers, non_shap_explainers = ( - SHAPInsight.SHAP_EXPLAINERS, - SHAPInsight.NON_SHAP_EXPLAINERS, - ) -else: - default_shap_plots, shap_explainers, non_shap_explainers = [], [], [] - -valid_hybrid_bayesian_recommenders = [ - param(TwoPhaseMetaRecommender(recommender=cls()), id=f"{cls.__name__}") - for cls in get_subclasses(BayesianRecommender) - if cls.compatibility == SearchSpaceType.HYBRID -] - - -def _test_shap_insights(campaign, explainer_cls, use_comp_rep, is_shap): - """Helper function for general SHAP explainer tests.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - try: - shap_insights = SHAPInsight.from_campaign( - campaign, - explainer_class=explainer_cls, - use_comp_rep=use_comp_rep, - ) - assert isinstance(shap_insights, insights.SHAPInsight) - assert isinstance( - shap_insights.explainer, - SHAPInsight.ALL_EXPLAINERS[explainer_cls], - ) - assert shap_insights._is_shap_explainer == is_shap - shap_explanation = shap_insights.explanation - assert isinstance(shap_explanation, shap.Explanation) - df = pd.DataFrame({"Num_disc_1": [0, 2]}) - with pytest.raises( - ValueError, - match="The provided data does not have the same " - "amount of parameters as the shap explainer background.", - ): - shap_insights._init_explanation(df) - except TypeError as e: - if "The selected explainer class" in str(e): - pytest.xfail("Unsupported model/explainer combination") - else: - raise e - except NotImplementedError as e: - if ( - "The selected explainer class" in str(e) - and not use_comp_rep - and not isinstance(explainer_cls, shap.explainers.KernelExplainer) - ): - pytest.xfail("Exp. rep. 
not supported") - else: - raise e - - -@pytest.mark.slow -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) -@pytest.mark.parametrize("explainer_cls", shap_explainers) -@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) -@pytest.mark.parametrize( - "parameter_names", - [ - ["Conti_finite1", "Conti_finite2"], - ["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"], - ], - ids=["continuous_params", "hybrid_params"], -) -def test_shapley_with_measurements(campaign, explainer_cls, use_comp_rep): - """Test the explain functionalities with measurements.""" - _test_shap_insights(campaign, explainer_cls, use_comp_rep, is_shap=True) - - -@pytest.mark.parametrize("explainer_cls", non_shap_explainers) -@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], - ids=["hybrid_params"], -) -def test_non_shapley_explainers(campaign, explainer_cls): - """Test the explain functionalities with the non-SHAP explainer MAPLE.""" - """Test the non-SHAP explainer in computational representation.""" - _test_shap_insights(campaign, explainer_cls, use_comp_rep=True, is_shap=False) - - -@pytest.mark.slow -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) -@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) -@pytest.mark.parametrize("plot_type", default_shap_plots) -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], - ids=["hybrid_params"], -) -def test_shap_insight_plots(campaign, use_comp_rep, plot_type): - """Test the default SHAP plots.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - shap_insights = SHAPInsight.from_campaign( - campaign, - 
use_comp_rep=use_comp_rep, - ) - with mock.patch("matplotlib.pyplot.show"): - shap_insights.plot(plot_type) - - -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) -def test_updated_campaign_explanations(campaign): - """Test explanations for campaigns with updated measurements.""" - with pytest.raises( - ValueError, - match="The campaign does not contain any measurements.", - ): - shap_insights = SHAPInsight.from_campaign(campaign) - run_iterations(campaign, n_iterations=2, batch_size=1) - shap_insights = SHAPInsight.from_campaign(campaign) - explanation_two_iter = shap_insights.explanation - run_iterations(campaign, n_iterations=2, batch_size=1) - shap_insights = SHAPInsight.from_campaign(campaign) - explanation_four_iter = shap_insights.explanation - assert explanation_two_iter != explanation_four_iter - - -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) -def test_shap_insights_from_recommender(campaign): - """Test the creation of SHAP insights from a recommender.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - recommender = campaign.recommender.recommender - shap_insight = SHAPInsight.from_recommender( - recommender, - campaign.searchspace, - campaign.objective, - campaign.measurements, - ) - assert isinstance(shap_insight, insights.SHAPInsight) +"""Tests for insights subpackage.""" + +from unittest import mock + +import pandas as pd +import pytest +from pytest import param + +from baybe._optional.info import INSIGHTS_INSTALLED +from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender +from baybe.recommenders.pure.bayesian.base import BayesianRecommender +from baybe.searchspace import SearchSpaceType +from baybe.utils.basic import get_subclasses +from tests.conftest import run_iterations + +pytestmark = pytest.mark.skipif( + not INSIGHTS_INSTALLED, 
reason="Optional 'insights' dependency not installed." +) + +if INSIGHTS_INSTALLED: + from baybe import insights + from baybe._optional.insights import shap + from baybe.insights.shap import ( + ALL_EXPLAINERS, + NON_SHAP_EXPLAINERS, + SHAP_EXPLAINERS, + SUPPORTED_SHAP_PLOTS, + SHAPInsight, + ) + + +valid_hybrid_bayesian_recommenders = [ + param(TwoPhaseMetaRecommender(recommender=cls()), id=f"{cls.__name__}") + for cls in get_subclasses(BayesianRecommender) + if cls.compatibility == SearchSpaceType.HYBRID +] + + +def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): + """Helper function for general SHAP explainer tests.""" + run_iterations(campaign, n_iterations=2, batch_size=1) + try: + shap_insight = SHAPInsight.from_campaign( + campaign, + explainer_cls=explainer_cls, + use_comp_rep=use_comp_rep, + ) + assert isinstance(shap_insight, insights.SHAPInsight) + assert isinstance( + shap_insight._explainer, + ALL_EXPLAINERS[explainer_cls], + ) + assert shap_insight.uses_shap_explainer == is_shap + shap_explanation = shap_insight.explanation + assert isinstance(shap_explanation, shap.Explanation) + df = pd.DataFrame({"Num_disc_1": [0, 2]}) + with pytest.raises( + ValueError, + match="The provided data does not have the same " + "amount of parameters as the shap explainer background.", + ): + shap_insight._init_explanation(df) + except TypeError as e: + if "The selected explainer class" in str(e): + pytest.xfail("Unsupported model/explainer combination") + else: + raise e + except NotImplementedError as e: + if ( + "The selected explainer class" in str(e) + and not use_comp_rep + and not isinstance(explainer_cls, shap.explainers.KernelExplainer) + ): + pytest.xfail("Exp. rep. 
not supported") + else: + raise e + + +@pytest.mark.slow +@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) +@pytest.mark.parametrize("explainer_cls", SHAP_EXPLAINERS) +@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) +@pytest.mark.parametrize( + "parameter_names", + [ + ["Conti_finite1", "Conti_finite2"], + ["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"], + ], + ids=["continuous_params", "hybrid_params"], +) +def test_shapley_with_measurements(campaign, explainer_cls, use_comp_rep): + """Test the explain functionalities with measurements.""" + _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap=True) + + +@pytest.mark.parametrize("explainer_cls", NON_SHAP_EXPLAINERS) +@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) +@pytest.mark.parametrize( + "parameter_names", + [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], + ids=["hybrid_params"], +) +def test_non_shapley_explainers(campaign, explainer_cls): + """Test the explain functionalities with the non-SHAP explainer MAPLE.""" + """Test the non-SHAP explainer in computational representation.""" + _test_shap_insight(campaign, explainer_cls, use_comp_rep=True, is_shap=False) + + +@pytest.mark.slow +@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) +@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) +@pytest.mark.parametrize("plot_type", SUPPORTED_SHAP_PLOTS) +@pytest.mark.parametrize( + "parameter_names", + [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], + ids=["hybrid_params"], +) +def test_shap_insight_plots(campaign, use_comp_rep, plot_type): + """Test the default SHAP plots.""" + run_iterations(campaign, n_iterations=2, batch_size=1) + shap_insight = SHAPInsight.from_campaign( + campaign, + 
use_comp_rep=use_comp_rep, + ) + with mock.patch("matplotlib.pyplot.show"): + shap_insight.plot(plot_type) + + +@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) +def test_updated_campaign_explanations(campaign): + """Test explanations for campaigns with updated measurements.""" + with pytest.raises( + ValueError, + match="The campaign does not contain any measurements.", + ): + shap_insight = SHAPInsight.from_campaign(campaign) + run_iterations(campaign, n_iterations=2, batch_size=1) + shap_insight = SHAPInsight.from_campaign(campaign) + explanation_two_iter = shap_insight.explanation + run_iterations(campaign, n_iterations=2, batch_size=1) + shap_insight = SHAPInsight.from_campaign(campaign) + explanation_four_iter = shap_insight.explanation + assert explanation_two_iter != explanation_four_iter + + +@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) +@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) +def test_shap_insight_from_recommender(campaign): + """Test the creation of SHAP insights from a recommender.""" + run_iterations(campaign, n_iterations=2, batch_size=1) + recommender = campaign.recommender.recommender + shap_insight = SHAPInsight.from_recommender( + recommender, + campaign.searchspace, + campaign.objective, + campaign.measurements, + ) + assert isinstance(shap_insight, insights.SHAPInsight) From a6eef091481d4bfa1c0ba123bf9f935b1060ceea Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Thu, 2 Jan 2025 18:13:33 +0100 Subject: [PATCH 26/92] Update tests --- baybe/insights/shap.py | 15 +++- tests/conftest.py | 11 ++- tests/insights/test_shap.py | 160 ++++++++++++++++++------------------ 3 files changed, 99 insertions(+), 87 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 50516a407..3f30b8c65 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -1,5 +1,7 @@ """SHAP insights.""" 
+from __future__ import annotations + import inspect import numbers import warnings @@ -26,7 +28,14 @@ def _get_explainer_maps() -> ( Returns: The maps for SHAP and non-SHAP explainers. """ - EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"] + EXCLUDED_EXPLAINER_KEYWORDS = [ + "Tree", + "GPU", + "Gradient", + "Sampling", + "Deep", + "Linear", + ] def _has_required_init_parameters(cls): """Check if non-shap initializer has required standard parameters.""" @@ -115,7 +124,7 @@ def from_campaign( explained_data: pd.DataFrame | None = None, explainer_cls: type[shap.Explainer] | str = "KernelExplainer", use_comp_rep: bool = False, - ): + ) -> SHAPInsight: """Create a SHAP insight from a campaign. Args: @@ -161,7 +170,7 @@ def from_recommender( explained_data: pd.DataFrame | None = None, explainer_cls: type[shap.Explainer] | str = "KernelExplainer", use_comp_rep: bool = False, - ): + ) -> SHAPInsight: """Create a SHAP insight from a recommender. Args: diff --git a/tests/conftest.py b/tests/conftest.py index c82630cd4..431364bb7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -167,7 +167,7 @@ def fixture_batch_size(request): @pytest.fixture( params=[5, pytest.param(8, marks=pytest.mark.slow)], name="n_grid_points", - ids=["grid5", "grid8"], + ids=["g5", "g8"], ) def fixture_n_grid_points(request): """Number of grid points used in e.g. the mixture tests. 
@@ -591,6 +591,13 @@ def fixture_campaign(parameters, constraints, recommender, objective): ) +@pytest.fixture(name="ongoing_campaign") +def fixture_ongoing_campaign(campaign, n_iterations, batch_size): + """Returns a campaign that already ran for several iterations.""" + run_iterations(campaign, n_iterations, batch_size) + return campaign + + @pytest.fixture(name="searchspace") def fixture_searchspace(parameters, constraints): """Returns a searchspace.""" @@ -880,8 +887,6 @@ def fixture_default_onnx_surrogate(onnx_str) -> CustomONNXSurrogate: # Reusables -# TODO consider turning this into a fixture returning a campaign after running some -# fake iterations @retry( stop=stop_after_attempt(5), retry=retry_any( diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 0b2122989..bfbba67fa 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -4,20 +4,29 @@ import pandas as pd import pytest -from pytest import param +from pytest import mark -from baybe._optional.info import INSIGHTS_INSTALLED -from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender -from baybe.recommenders.pure.bayesian.base import BayesianRecommender -from baybe.searchspace import SearchSpaceType -from baybe.utils.basic import get_subclasses +from baybe._optional.info import SHAP_INSTALLED from tests.conftest import run_iterations -pytestmark = pytest.mark.skipif( - not INSIGHTS_INSTALLED, reason="Optional 'insights' dependency not installed." 
-) +# File-wide parameterization settings +pytestmark = [ + mark.skipif(not SHAP_INSTALLED, reason="Optional shap package not installed."), + mark.parametrize("n_grid_points", [5], ids=["g5"]), + mark.parametrize("n_iterations", [2], ids=["i2"]), + mark.parametrize("batch_size", [2], ids=["b2"]), + mark.parametrize( + "parameter_names", + [ + ["Conti_finite1", "Conti_finite2"], + ["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"], + ], + ids=["conti_params", "hybrid_params"], + ), +] + -if INSIGHTS_INSTALLED: +if SHAP_INSTALLED: from baybe import insights from baybe._optional.insights import shap from baybe.insights.shap import ( @@ -27,19 +36,18 @@ SUPPORTED_SHAP_PLOTS, SHAPInsight, ) - - -valid_hybrid_bayesian_recommenders = [ - param(TwoPhaseMetaRecommender(recommender=cls()), id=f"{cls.__name__}") - for cls in get_subclasses(BayesianRecommender) - if cls.compatibility == SearchSpaceType.HYBRID -] +else: + ALL_EXPLAINERS = [] + NON_SHAP_EXPLAINERS = [] + SHAP_EXPLAINERS = [] + SUPPORTED_SHAP_PLOTS = [] def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): """Helper function for general SHAP explainer tests.""" - run_iterations(campaign, n_iterations=2, batch_size=1) + # run_iterations(campaign, n_iterations=2, batch_size=5) try: + # Sanity check explainer shap_insight = SHAPInsight.from_campaign( campaign, explainer_cls=explainer_cls, @@ -51,15 +59,10 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): ALL_EXPLAINERS[explainer_cls], ) assert shap_insight.uses_shap_explainer == is_shap + + # Sanity check explanation shap_explanation = shap_insight.explanation assert isinstance(shap_explanation, shap.Explanation) - df = pd.DataFrame({"Num_disc_1": [0, 2]}) - with pytest.raises( - ValueError, - match="The provided data does not have the same " - "amount of parameters as the shap explainer background.", - ): - shap_insight._init_explanation(df) except TypeError as e: if "The selected explainer class" in str(e): 
pytest.xfail("Unsupported model/explainer combination") @@ -76,86 +79,81 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): raise e -@pytest.mark.slow -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) -@pytest.mark.parametrize("explainer_cls", SHAP_EXPLAINERS) -@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) -@pytest.mark.parametrize( - "parameter_names", - [ - ["Conti_finite1", "Conti_finite2"], - ["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"], - ], - ids=["continuous_params", "hybrid_params"], -) -def test_shapley_with_measurements(campaign, explainer_cls, use_comp_rep): +@mark.slow +@mark.parametrize("explainer_cls", SHAP_EXPLAINERS) +@mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) +def test_shap_explainers(ongoing_campaign, explainer_cls, use_comp_rep): """Test the explain functionalities with measurements.""" - _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap=True) + _test_shap_insight(ongoing_campaign, explainer_cls, use_comp_rep, is_shap=True) -@pytest.mark.parametrize("explainer_cls", NON_SHAP_EXPLAINERS) -@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], - ids=["hybrid_params"], -) -def test_non_shapley_explainers(campaign, explainer_cls): +@mark.parametrize("explainer_cls", NON_SHAP_EXPLAINERS) +def test_non_shap_explainers(ongoing_campaign, explainer_cls): """Test the explain functionalities with the non-SHAP explainer MAPLE.""" """Test the non-SHAP explainer in computational representation.""" - _test_shap_insight(campaign, explainer_cls, use_comp_rep=True, is_shap=False) - - -@pytest.mark.slow -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) 
-@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) -@pytest.mark.parametrize("plot_type", SUPPORTED_SHAP_PLOTS) -@pytest.mark.parametrize( - "parameter_names", - [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]], - ids=["hybrid_params"], -) -def test_shap_insight_plots(campaign, use_comp_rep, plot_type): + _test_shap_insight( + ongoing_campaign, explainer_cls, use_comp_rep=True, is_shap=False + ) + + +@mark.slow +@mark.parametrize("explainer_cls", ["KernelExplainer"], ids=["KernelExplainer"]) +@mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) +def test_invalid_explained_data(ongoing_campaign, explainer_cls, use_comp_rep): + """Test invalid explained data.""" + shap_insight = SHAPInsight.from_campaign( + ongoing_campaign, + explainer_cls=explainer_cls, + use_comp_rep=use_comp_rep, + ) + df = pd.DataFrame({"Num_disc_1": [0, 2]}) + with pytest.raises( + ValueError, + match="The provided data does not have the same amount of parameters as the " + "shap explainer background.", + ): + shap_insight._init_explanation(df) + + +@mark.slow +@mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) +@mark.parametrize("plot_type", SUPPORTED_SHAP_PLOTS) +def test_plots(ongoing_campaign, use_comp_rep, plot_type): """Test the default SHAP plots.""" - run_iterations(campaign, n_iterations=2, batch_size=1) shap_insight = SHAPInsight.from_campaign( - campaign, + ongoing_campaign, use_comp_rep=use_comp_rep, ) with mock.patch("matplotlib.pyplot.show"): shap_insight.plot(plot_type) -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) -def test_updated_campaign_explanations(campaign): +def test_updated_campaign_explanations(campaign, n_iterations, batch_size): """Test explanations for campaigns with updated measurements.""" with pytest.raises( ValueError, match="The campaign does not contain any measurements.", ): - shap_insight 
= SHAPInsight.from_campaign(campaign) - run_iterations(campaign, n_iterations=2, batch_size=1) + SHAPInsight.from_campaign(campaign) + + run_iterations(campaign, n_iterations=n_iterations, batch_size=batch_size) shap_insight = SHAPInsight.from_campaign(campaign) - explanation_two_iter = shap_insight.explanation - run_iterations(campaign, n_iterations=2, batch_size=1) + explanation_1 = shap_insight.explanation + + run_iterations(campaign, n_iterations=n_iterations, batch_size=batch_size) shap_insight = SHAPInsight.from_campaign(campaign) - explanation_four_iter = shap_insight.explanation - assert explanation_two_iter != explanation_four_iter + explanation_2 = shap_insight.explanation + + assert explanation_1 != explanation_2, "SHAP explanations should not be identical." -@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders) -@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"]) -def test_shap_insight_from_recommender(campaign): +def test_creation_from_recommender(ongoing_campaign): """Test the creation of SHAP insights from a recommender.""" - run_iterations(campaign, n_iterations=2, batch_size=1) - recommender = campaign.recommender.recommender + recommender = ongoing_campaign.recommender.recommender shap_insight = SHAPInsight.from_recommender( recommender, - campaign.searchspace, - campaign.objective, - campaign.measurements, + ongoing_campaign.searchspace, + ongoing_campaign.objective, + ongoing_campaign.measurements, ) assert isinstance(shap_insight, insights.SHAPInsight) From b93a376ed75fbb35c5040ee639f507418fad2164 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 3 Jan 2025 11:27:35 +0100 Subject: [PATCH 27/92] Enhance error message --- baybe/insights/shap.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 3f30b8c65..526a00462 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -68,10 +68,10 @@ def 
_has_required_init_parameters(cls): ALL_EXPLAINERS = SHAP_EXPLAINERS | NON_SHAP_EXPLAINERS SUPPORTED_SHAP_PLOTS = { "bar", - "scatter", - "heatmap", - "force", "beeswarm", + "force", + "heatmap", + "scatter", } @@ -354,7 +354,7 @@ def plot(self, plot_type: str, **kwargs: dict) -> None: Args: plot_type: The type of plot to be created. Supported types are: - "bar", "scatter", "heatmap", "force", "beeswarm". + "bar", "beeswarm", "force", "heatmap", "scatter". **kwargs: Additional keyword arguments to be passed to the plot function. Raises: @@ -370,7 +370,10 @@ def plot(self, plot_type: str, **kwargs: dict) -> None: or (plot is None) or (not callable(plot)) ): - raise ValueError(f"Invalid plot type: '{plot_type}'.") + raise ValueError( + f"Invalid plot type: '{plot_type}'. Available options: " + f"{SUPPORTED_SHAP_PLOTS}." + ) plot(self.explanation, **kwargs) From 271645f5753fcb6ebf10db4181d28a1f70b3ec6b Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 3 Jan 2025 11:44:23 +0100 Subject: [PATCH 28/92] Add special handling for Lime --- baybe/insights/shap.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 526a00462..146304a40 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -260,6 +260,11 @@ def model(x): return output.detach().numpy() + # Handle special settings + if "Lime" in explainer_cls.__name__: + # Lime default mode is otherwise set to 'classification' + kwargs["mode"] = "regression" + try: shap_explainer = explainer_cls(model, bg_data, **kwargs) From 59970957985c0c5f327f7753b0f7ce6d60657129 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 6 Jan 2025 12:05:16 +0100 Subject: [PATCH 29/92] Rename bg_data to background_data --- baybe/insights/shap.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 146304a40..d2f093f47 100644 --- a/baybe/insights/shap.py +++ 
b/baybe/insights/shap.py @@ -82,7 +82,7 @@ class SHAPInsight(Insight): This also supports LIME and MAPLE explainers via ways provided by the shap module. """ - bg_data: pd.DataFrame = field() + background_data: pd.DataFrame = field() """The background data set used to build the explainer.""" explained_data: pd.DataFrame | None = field(default=None) @@ -109,7 +109,7 @@ class SHAPInsight(Insight): def __attrs_post_init__(self): """Initialize the explainer.""" - self._explainer = self._init_explainer(self.bg_data, self.explainer_cls) + self._explainer = self._init_explainer(self.background_data, self.explainer_cls) @property def uses_shap_explainer(self): @@ -150,10 +150,11 @@ def from_campaign( f"measurements as part of the campaign." ) data = campaign.measurements[[p.name for p in campaign.parameters]].copy() + background_data = campaign.searchspace.transform(data) if use_comp_rep else data return cls( campaign.get_surrogate(), - bg_data=campaign.searchspace.transform(data) if use_comp_rep else data, + background_data=background_data, explainer_cls=explainer_cls, use_comp_rep=use_comp_rep, explained_data=explained_data, @@ -204,7 +205,7 @@ def from_recommender( return cls( surrogate_model, - bg_data=searchspace.transform(measurements) + background_data=searchspace.transform(measurements) if use_comp_rep else measurements, explained_data=explained_data, @@ -214,14 +215,14 @@ def from_recommender( def _init_explainer( self, - bg_data: pd.DataFrame, + background_data: pd.DataFrame, explainer_cls: type[shap.Explainer] = shap.KernelExplainer, **kwargs, ) -> shap.Explainer: """Create a SHAP explainer. Args: - bg_data: The background data set. + background_data: The background data set. explainer_cls: The SHAP explainer class that is used to generate the explanation. **kwargs: Additional keyword arguments to be passed to the explainer. @@ -242,7 +243,7 @@ def _init_explainer( "explainer." 
) - if bg_data.empty: + if background_data.empty: raise ValueError("The provided background data set is empty.") if self.use_comp_rep: @@ -255,7 +256,7 @@ def model(x): else: def model(x): - df = pd.DataFrame(x, columns=bg_data.columns) + df = pd.DataFrame(x, columns=background_data.columns) output = self.surrogate.posterior(df).mean return output.detach().numpy() @@ -266,11 +267,11 @@ def model(x): kwargs["mode"] = "regression" try: - shap_explainer = explainer_cls(model, bg_data, **kwargs) + shap_explainer = explainer_cls(model, background_data, **kwargs) # Explain first two data points to ensure that the explainer is working if self.uses_shap_explainer: - shap_explainer(self.bg_data.iloc[0:1]) + shap_explainer(self.background_data.iloc[0:1]) except shap.utils._exceptions.InvalidModelError: raise TypeError( f"The selected explainer class {explainer_cls} does not support the " @@ -306,8 +307,8 @@ def _init_explanation( parameters as the SHAP explainer background """ if explained_data is None: - explained_data = self.bg_data - elif not self.bg_data.shape[1] == explained_data.shape[1]: + explained_data = self.background_data + elif not self.background_data.shape[1] == explained_data.shape[1]: raise ValueError( "The provided data does not have the same amount of " "parameters as the shap explainer background." 
@@ -328,7 +329,7 @@ def _init_explanation( explanations = shap.Explanation( values=attributions, - base_values=self._explainer.model(self.bg_data).mean(), + base_values=self._explainer.model(self.background_data).mean(), data=explained_data, feature_names=explained_data.columns.values, ) @@ -392,8 +393,8 @@ def _plot_shap_scatter(self, **kwargs: dict) -> None: def is_not_numeric_column(col): return np.array([not isinstance(v, numbers.Number) for v in col]).any() - if np.ndim(self.bg_data) == 1: - if is_not_numeric_column(self.bg_data): + if np.ndim(self.background_data) == 1: + if is_not_numeric_column(self.background_data): warnings.warn( "Cannot plot scatter plot for the provided " "explanation as it contains non-numeric values." @@ -402,12 +403,12 @@ def is_not_numeric_column(col): shap.plots.scatter(self.explanation) else: # Type checking for mypy - assert isinstance(self.bg_data, pd.DataFrame) + assert isinstance(self.background_data, pd.DataFrame) - mask = self.bg_data.iloc[0].apply(lambda x: not isinstance(x, str)) + mask = self.background_data.iloc[0].apply(lambda x: not isinstance(x, str)) number_enum = np.where(mask)[0].tolist() - if len(number_enum) < len(self.bg_data.iloc[0]): + if len(number_enum) < len(self.background_data.iloc[0]): warnings.warn( "Cannot plot SHAP scatter plot for all parameters as some contain " "non-numeric values." 
From ec850ba42ede8c526ce441e7c85fa6e37186dcb6 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 6 Jan 2025 14:05:59 +0100 Subject: [PATCH 30/92] Add missing validators --- baybe/insights/shap.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index d2f093f47..f77d19bce 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd from attrs import define, field +from attrs.validators import instance_of, optional from typing_extensions import override from baybe import Campaign @@ -82,10 +83,12 @@ class SHAPInsight(Insight): This also supports LIME and MAPLE explainers via ways provided by the shap module. """ - background_data: pd.DataFrame = field() + background_data: pd.DataFrame = field(validator=instance_of(pd.DataFrame)) """The background data set used to build the explainer.""" - explained_data: pd.DataFrame | None = field(default=None) + explained_data: pd.DataFrame | None = field( + default=None, validator=optional(instance_of(pd.DataFrame)) + ) """The data for which a SHAP explanation is generated.""" explainer_cls: type[shap.Explainer] | str = field( @@ -98,7 +101,7 @@ class SHAPInsight(Insight): available via 'shap.explainers.other'. 
""" - use_comp_rep: bool = field(default=False) + use_comp_rep: bool = field(default=False, validator=instance_of(bool)) """Flag for toggling in which representation the insight should be provided.""" _explainer: shap.Explainer | None = field(default=None, init=False) From ae33f55207770eeac3152ee77edcd97d5bb3931d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 6 Jan 2025 14:06:13 +0100 Subject: [PATCH 31/92] Fix type annotations --- baybe/insights/shap.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index f77d19bce..46a90221e 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -5,6 +5,7 @@ import inspect import numbers import warnings +from typing import Any import numpy as np import pandas as pd @@ -91,7 +92,8 @@ class SHAPInsight(Insight): ) """The data for which a SHAP explanation is generated.""" - explainer_cls: type[shap.Explainer] | str = field( + # FIXME[typing]: https://github.com/python/mypy/issues/10998 + explainer_cls: type[shap.Explainer] = field( # type: ignore[assignment] default="KernelExplainer", converter=lambda x: ALL_EXPLAINERS[x] if isinstance(x, str) else x, ) @@ -115,7 +117,7 @@ def __attrs_post_init__(self): self._explainer = self._init_explainer(self.background_data, self.explainer_cls) @property - def uses_shap_explainer(self): + def uses_shap_explainer(self) -> bool: """Whether the explainer is a SHAP explainer or not (e.g. MAPLE, LIME).""" return not self.explainer_cls.__module__.startswith("shap.explainers.other.") @@ -358,7 +360,7 @@ def explanation(self) -> shap.Explanation: return self._explanation - def plot(self, plot_type: str, **kwargs: dict) -> None: + def plot(self, plot_type: str, **kwargs: Any) -> None: """Plot the Shapley values using the provided plot type. 
Args: @@ -386,7 +388,7 @@ def plot(self, plot_type: str, **kwargs: dict) -> None: plot(self.explanation, **kwargs) - def _plot_shap_scatter(self, **kwargs: dict) -> None: + def _plot_shap_scatter(self, **kwargs: Any) -> None: """Plot the Shapley values as scatter plot while leaving out string values. Args: From c1d69a6496318a1aaef6459109ca842b71e00db2 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 6 Jan 2025 14:08:59 +0100 Subject: [PATCH 32/92] Replace unnecessary post init call --- baybe/insights/shap.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 46a90221e..12f6d5372 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -9,7 +9,7 @@ import numpy as np import pandas as pd -from attrs import define, field +from attrs import Factory, define, field from attrs.validators import instance_of, optional from typing_extensions import override @@ -106,16 +106,18 @@ class SHAPInsight(Insight): use_comp_rep: bool = field(default=False, validator=instance_of(bool)) """Flag for toggling in which representation the insight should be provided.""" - _explainer: shap.Explainer | None = field(default=None, init=False) + _explainer: shap.Explainer | None = field( + default=Factory( + lambda self: self._init_explainer(self.background_data, self.explainer_cls), + takes_self=True, + ), + init=False, + ) """The explainer generated from the model and background data.""" _explanation: shap.Explanation | None = field(default=None, init=False) """The explanation generated.""" - def __attrs_post_init__(self): - """Initialize the explainer.""" - self._explainer = self._init_explainer(self.background_data, self.explainer_cls) - @property def uses_shap_explainer(self) -> bool: """Whether the explainer is a SHAP explainer or not (e.g. 
MAPLE, LIME).""" From e58c04a38e0cc0ef5dde1d1d5905d481048a469f Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 6 Jan 2025 14:22:25 +0100 Subject: [PATCH 33/92] Remove unnecessary guard clause The input type is BayesianRecommender (not Recommender), which by definition has a surrogate model. --- baybe/insights/shap.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 12f6d5372..6ba13de13 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -196,16 +196,7 @@ def from_recommender( Returns: The SHAP insight object. - - Raises: - ValueError: If the recommender has not implemented a "get_surrogate" method. """ - if not hasattr(recommender, "get_surrogate"): - raise ValueError( - f"The provided recommender does not provide a surrogate model. A " - f"'{cls.__name__}' needs a surrogate model and thus only works with " - f"model-based recommenders." - ) surrogate_model = recommender.get_surrogate( searchspace, objective, measurements ) From f5530bc09743401a2ada772a9a085d6257c6298b Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 6 Jan 2025 14:38:21 +0100 Subject: [PATCH 34/92] Turn method guard clause into proper attribute validator --- baybe/insights/shap.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 6ba13de13..571bb9420 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -118,6 +118,14 @@ class SHAPInsight(Insight): _explanation: shap.Explanation | None = field(default=None, init=False) """The explanation generated.""" + @use_comp_rep.validator + def _validate_use_comp_rep(self, _, value: bool) -> None: + if not self.uses_shap_explainer and not value: + raise NotImplementedError( + "Experimental representation is not supported for non-Kernel SHAP " + "explainer." + ) + @property def uses_shap_explainer(self) -> bool: """Whether the explainer is a SHAP explainer or not (e.g. 
MAPLE, LIME).""" @@ -229,18 +237,10 @@ def _init_explainer( shap.Explainer: The created explainer object. Raises: - NotImplementedError: If the provided explainer class does - not support the experimental representation. ValueError: If the provided background data set is empty. TypeError: If the provided explainer class does not support the campaign surrogate. """ - if not self.uses_shap_explainer and not self.use_comp_rep: - raise NotImplementedError( - "Experimental representation is not supported for non-Kernel SHAP " - "explainer." - ) - if background_data.empty: raise ValueError("The provided background data set is empty.") From 313f39422f1cd95b381551baa0ee33e7d465ddce Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 6 Jan 2025 14:54:36 +0100 Subject: [PATCH 35/92] Improve model function definition --- baybe/insights/shap.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 571bb9420..d19e1bc97 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -8,6 +8,7 @@ from typing import Any import numpy as np +import numpy.typing as npt import pandas as pd from attrs import Factory, define, field from attrs.validators import instance_of, optional @@ -244,20 +245,23 @@ def _init_explainer( if background_data.empty: raise ValueError("The provided background data set is empty.") + import torch + if self.use_comp_rep: - def model(x): + def model(x: npt.ArrayLike) -> np.ndarray: tensor = to_tensor(x) - output = self.surrogate._posterior_comp(tensor).mean + with torch.no_grad(): + output = self.surrogate._posterior_comp(tensor).mean + return output.numpy() - return output.detach().numpy() else: - def model(x): + def model(x: npt.ArrayLike) -> np.ndarray: df = pd.DataFrame(x, columns=background_data.columns) - output = self.surrogate.posterior(df).mean - - return output.detach().numpy() + with torch.no_grad(): + output = self.surrogate.posterior(df).mean + return 
output.numpy() # Handle special settings if "Lime" in explainer_cls.__name__: From e3bb895ec0a6d61840da0b886948cdf507dd8218 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Tue, 7 Jan 2025 09:06:33 +0100 Subject: [PATCH 36/92] Refactor plot methods to return plt.Axes if plots not shown directly. Removed "force" from SUPPORTED_SHAP_PLOTS. --- baybe/insights/shap.py | 50 ++++++++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 146304a40..34bed36e2 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -6,6 +6,7 @@ import numbers import warnings +import matplotlib.pyplot as plt import numpy as np import pandas as pd from attrs import define, field @@ -69,7 +70,6 @@ def _has_required_init_parameters(cls): SUPPORTED_SHAP_PLOTS = { "bar", "beeswarm", - "force", "heatmap", "scatter", } @@ -336,7 +336,9 @@ def _init_explanation( else: explanations = self._explainer(explained_data) - """Ensure that the explanation object is of the correct dimensionality.""" + # Reduce dimensionality of explanations to 2D in case + # a 3D explanation is returned. This is the case for + # some explainers even if only one output is present. if len(explanations.shape) == 2: return explanations if len(explanations.shape) == 3: @@ -354,7 +356,7 @@ def explanation(self) -> shap.Explanation: return self._explanation - def plot(self, plot_type: str, **kwargs: dict) -> None: + def plot(self, plot_type: str, **kwargs: dict) -> None | plt.Axes: """Plot the Shapley values using the provided plot type. Args: @@ -362,32 +364,52 @@ def plot(self, plot_type: str, **kwargs: dict) -> None: "bar", "beeswarm", "force", "heatmap", "scatter". **kwargs: Additional keyword arguments to be passed to the plot function. + Returns: + None | plt.Axes: The plot object if 'show' is set to False. + Raises: ValueError: If the provided plot type is not supported. 
""" + # Extract the 'show' argument from the kwargs + show = kwargs.pop("show", True) + + plot = None + + # Special case for scatter plot if plot_type == "scatter": - self._plot_shap_scatter(**kwargs) + plot = self._plot_shap_scatter(show=show, **kwargs) + if not show: + return plot return None - plot = getattr(shap.plots, plot_type, None) + # Cases for all other plots + plot_func = getattr(shap.plots, plot_type, None) if ( (plot_type not in SUPPORTED_SHAP_PLOTS) - or (plot is None) - or (not callable(plot)) + or (plot_func is None) + or (not callable(plot_func)) ): raise ValueError( f"Invalid plot type: '{plot_type}'. Available options: " f"{SUPPORTED_SHAP_PLOTS}." ) - plot(self.explanation, **kwargs) + plot = plot_func(self.explanation, show=show, **kwargs) + if not show: + return plot + return None - def _plot_shap_scatter(self, **kwargs: dict) -> None: + def _plot_shap_scatter(self, show: bool = True, **kwargs: dict) -> None | plt.Axes: """Plot the Shapley values as scatter plot while leaving out string values. Args: + show: Whether to call plt.show() after plotting or not. **kwargs: Additional keyword arguments to be passed to the plot function. + + Returns: + None | plt.Axes: The plot object if 'show' is set to False. """ + plot = None def is_not_numeric_column(col): return np.array([not isinstance(v, numbers.Number) for v in col]).any() @@ -399,7 +421,7 @@ def is_not_numeric_column(col): "explanation as it contains non-numeric values." ) else: - shap.plots.scatter(self.explanation) + plot = shap.plots.scatter(self.explanation, show=show, **kwargs) else: # Type checking for mypy assert isinstance(self.bg_data, pd.DataFrame) @@ -412,4 +434,10 @@ def is_not_numeric_column(col): "Cannot plot SHAP scatter plot for all parameters as some contain " "non-numeric values." 
) - shap.plots.scatter(self.explanation[:, number_enum], **kwargs) + plot = shap.plots.scatter( + self.explanation[:, number_enum], show=show, **kwargs + ) + + if not show: + return plot + return None From ff85048d8f0d9c5475e601ff99ce90649468aa9b Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 7 Jan 2025 11:15:06 +0100 Subject: [PATCH 37/92] Drop Insights base class --- baybe/insights/base.py | 67 ------------------------------------------ baybe/insights/shap.py | 10 +++---- 2 files changed, 5 insertions(+), 72 deletions(-) delete mode 100644 baybe/insights/base.py diff --git a/baybe/insights/base.py b/baybe/insights/base.py deleted file mode 100644 index 513105d87..000000000 --- a/baybe/insights/base.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Base class for all insights.""" - -from __future__ import annotations - -from abc import ABC - -import pandas as pd -from attrs import define, field - -from baybe import Campaign -from baybe.objectives.base import Objective -from baybe.recommenders.pure.bayesian.base import BayesianRecommender -from baybe.searchspace import SearchSpace -from baybe.surrogates.base import SurrogateProtocol - - -@define -class Insight(ABC): - """Base class for all insights.""" - - surrogate: SurrogateProtocol = field() - """The surrogate model that is supposed bo be analyzed.""" - - @classmethod - def from_campaign(cls, campaign: Campaign) -> Insight: - """Create an insight from a campaign. - - Args: - campaign: A baybe Campaign object. - - Returns: - The Insight object. - """ - return cls(campaign.get_surrogate()) - - @classmethod - def from_recommender( - cls, - recommender: BayesianRecommender, - searchspace: SearchSpace, - objective: Objective, - measurements: pd.DataFrame, - ) -> Insight: - """Create an insight from a recommender. - - Args: - recommender: A model-based recommender. - searchspace: The search space used for recommendations. - objective: The objective of the recommendation. 
- measurements: The measurements in experimental representation. - - Returns: - The Insight object. - - Raises: - ValueError: If the provided recommender is not surrogate-based. - """ - if not hasattr(recommender, "get_surrogate"): - raise ValueError( - f"The provided recommender of type '{recommender.__class__.__name__}' " - f"does not provide a surrogate model." - ) - surrogate_model = recommender.get_surrogate( - searchspace, objective, measurements - ) - - return cls(surrogate_model) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index a95b460c4..b9cda13f9 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -12,14 +12,13 @@ import pandas as pd from attrs import Factory, define, field from attrs.validators import instance_of, optional -from typing_extensions import override from baybe import Campaign from baybe._optional.insights import shap -from baybe.insights.base import Insight from baybe.objectives.base import Objective from baybe.recommenders.pure.bayesian.base import BayesianRecommender from baybe.searchspace import SearchSpace +from baybe.surrogates.base import SurrogateProtocol from baybe.utils.dataframe import to_tensor @@ -78,12 +77,15 @@ def _has_required_init_parameters(cls): @define -class SHAPInsight(Insight): +class SHAPInsight: """Class for SHAP-based feature importance insights. This also supports LIME and MAPLE explainers via ways provided by the shap module. """ + surrogate: SurrogateProtocol = field() + """The surrogate model that is supposed bo be analyzed.""" + background_data: pd.DataFrame = field(validator=instance_of(pd.DataFrame)) """The background data set used to build the explainer.""" @@ -131,7 +133,6 @@ def uses_shap_explainer(self) -> bool: """Whether the explainer is a SHAP explainer or not (e.g. 
MAPLE, LIME).""" return not self.explainer_cls.__module__.startswith("shap.explainers.other.") - @override @classmethod def from_campaign( cls, @@ -175,7 +176,6 @@ def from_campaign( explained_data=explained_data, ) - @override @classmethod def from_recommender( cls, From dd83e9968e76c9d0d2508e02c9ac238af162a027 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 7 Jan 2025 11:24:14 +0100 Subject: [PATCH 38/92] Turn explanation property into explain method --- baybe/insights/shap.py | 9 ++++----- tests/insights/test_shap.py | 6 +++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index b9cda13f9..318f5cc07 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -350,8 +350,7 @@ def _init_explanation( f"invalid dimensionality of {len(explanations.shape)}." ) - @property - def explanation(self) -> shap.Explanation: + def explain(self) -> shap.Explanation: """Get the SHAP explanation object. Uses lazy evaluation.""" if self._explanation is None: self._explanation = self._init_explanation() @@ -396,7 +395,7 @@ def plot(self, plot_type: str, **kwargs: dict) -> None | plt.Axes: f"{SUPPORTED_SHAP_PLOTS}." ) - plot = plot_func(self.explanation, show=show, **kwargs) + plot = plot_func(self.explain(), show=show, **kwargs) if not show: return plot return None @@ -423,7 +422,7 @@ def is_not_numeric_column(col): "explanation as it contains non-numeric values." ) else: - plot = shap.plots.scatter(self.explanation, show=show, **kwargs) + plot = shap.plots.scatter(self.explain(), show=show, **kwargs) else: # Type checking for mypy assert isinstance(self.background_data, pd.DataFrame) @@ -437,7 +436,7 @@ def is_not_numeric_column(col): "non-numeric values." 
) plot = shap.plots.scatter( - self.explanation[:, number_enum], show=show, **kwargs + self.explain()[:, number_enum], show=show, **kwargs ) if not show: diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index bfbba67fa..0483cd208 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -61,7 +61,7 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): assert shap_insight.uses_shap_explainer == is_shap # Sanity check explanation - shap_explanation = shap_insight.explanation + shap_explanation = shap_insight.explain() assert isinstance(shap_explanation, shap.Explanation) except TypeError as e: if "The selected explainer class" in str(e): @@ -138,11 +138,11 @@ def test_updated_campaign_explanations(campaign, n_iterations, batch_size): run_iterations(campaign, n_iterations=n_iterations, batch_size=batch_size) shap_insight = SHAPInsight.from_campaign(campaign) - explanation_1 = shap_insight.explanation + explanation_1 = shap_insight.explain() run_iterations(campaign, n_iterations=n_iterations, batch_size=batch_size) shap_insight = SHAPInsight.from_campaign(campaign) - explanation_2 = shap_insight.explanation + explanation_2 = shap_insight.explain() assert explanation_1 != explanation_2, "SHAP explanations should not be identical." 
From 740a4a1b566c0a76ad98b3b46a886bd07f8f3f66 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 7 Jan 2025 12:11:39 +0100 Subject: [PATCH 39/92] Pass data to be explained as method argument --- baybe/insights/shap.py | 71 +++++++++++++------------------------ tests/insights/test_shap.py | 21 +++++++---- 2 files changed, 40 insertions(+), 52 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 318f5cc07..299a388e2 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -11,7 +11,7 @@ import numpy.typing as npt import pandas as pd from attrs import Factory, define, field -from attrs.validators import instance_of, optional +from attrs.validators import instance_of from baybe import Campaign from baybe._optional.insights import shap @@ -89,11 +89,6 @@ class SHAPInsight: background_data: pd.DataFrame = field(validator=instance_of(pd.DataFrame)) """The background data set used to build the explainer.""" - explained_data: pd.DataFrame | None = field( - default=None, validator=optional(instance_of(pd.DataFrame)) - ) - """The data for which a SHAP explanation is generated.""" - # FIXME[typing]: https://github.com/python/mypy/issues/10998 explainer_cls: type[shap.Explainer] = field( # type: ignore[assignment] default="KernelExplainer", @@ -117,9 +112,6 @@ class SHAPInsight: ) """The explainer generated from the model and background data.""" - _explanation: shap.Explanation | None = field(default=None, init=False) - """The explanation generated.""" - @use_comp_rep.validator def _validate_use_comp_rep(self, _, value: bool) -> None: if not self.uses_shap_explainer and not value: @@ -137,7 +129,6 @@ def uses_shap_explainer(self) -> bool: def from_campaign( cls, campaign: Campaign, - explained_data: pd.DataFrame | None = None, explainer_cls: type[shap.Explainer] | str = "KernelExplainer", use_comp_rep: bool = False, ) -> SHAPInsight: @@ -145,8 +136,6 @@ def from_campaign( Args: campaign: The campaign which holds the recommender and 
model. - explained_data: The data set to be explained. If None, all measurements - from the campaign are used. explainer_cls: The SHAP explainer class that is used to generate the explanation. use_comp_rep: @@ -173,7 +162,6 @@ def from_campaign( background_data=background_data, explainer_cls=explainer_cls, use_comp_rep=use_comp_rep, - explained_data=explained_data, ) @classmethod @@ -183,7 +171,6 @@ def from_recommender( searchspace: SearchSpace, objective: Objective, measurements: pd.DataFrame, - explained_data: pd.DataFrame | None = None, explainer_cls: type[shap.Explainer] | str = "KernelExplainer", use_comp_rep: bool = False, ) -> SHAPInsight: @@ -195,8 +182,6 @@ def from_recommender( objective: The objective for the recommender. measurements: The background data set for Explainer. This is used the measurement data set for the recommender. - explained_data: The data set to be explained. If None, - the background data set is used. explainer_cls: The explainer class. use_comp_rep: Whether to analyze the model in computational representation @@ -214,7 +199,6 @@ def from_recommender( background_data=searchspace.transform(measurements) if use_comp_rep else measurements, - explained_data=explained_data, explainer_cls=explainer_cls, use_comp_rep=use_comp_rep, ) @@ -290,15 +274,11 @@ def model(x: npt.ArrayLike) -> np.ndarray: raise e return shap_explainer - def _init_explanation( - self, - explained_data: pd.DataFrame | None = None, - ) -> shap.Explanation: + def explain(self, df: pd.DataFrame, /) -> shap.Explanation: """Compute the Shapley values based on the chosen explainer and data set. Args: - explained_data: The data set for which the Shapley values should be - computed. + df: The data set for which the Shapley values should be computed. Returns: shap.Explanation: The computed Shapley values. 
@@ -307,9 +287,9 @@ def _init_explanation( ValueError: If the provided data set does not have the same amount of parameters as the SHAP explainer background """ - if explained_data is None: - explained_data = self.background_data - elif not self.background_data.shape[1] == explained_data.shape[1]: + if df is None: + df = self.background_data + elif not self.background_data.shape[1] == df.shape[1]: raise ValueError( "The provided data does not have the same amount of " "parameters as the shap explainer background." @@ -322,21 +302,21 @@ def _init_explanation( # Return attributions for non-SHAP explainers if self._explainer.__module__.endswith("maple"): # Additional argument for maple to increase comparability to SHAP - attributions = self._explainer.attributions( - explained_data, multiply_by_input=True - )[0] + attributions = self._explainer.attributions(df, multiply_by_input=True)[ + 0 + ] else: - attributions = self._explainer.attributions(explained_data)[0] + attributions = self._explainer.attributions(df)[0] explanations = shap.Explanation( values=attributions, base_values=self._explainer.model(self.background_data).mean(), - data=explained_data, - feature_names=explained_data.columns.values, + data=df, + feature_names=df.columns.values, ) return explanations else: - explanations = self._explainer(explained_data) + explanations = self._explainer(df) # Reduce dimensionality of explanations to 2D in case # a 3D explanation is returned. This is the case for @@ -350,17 +330,13 @@ def _init_explanation( f"invalid dimensionality of {len(explanations.shape)}." ) - def explain(self) -> shap.Explanation: - """Get the SHAP explanation object. 
Uses lazy evaluation.""" - if self._explanation is None: - self._explanation = self._init_explanation() - - return self._explanation - - def plot(self, plot_type: str, **kwargs: dict) -> None | plt.Axes: + def plot( + self, df: pd.DataFrame, /, plot_type: str, **kwargs: dict + ) -> None | plt.Axes: """Plot the Shapley values using the provided plot type. Args: + df: The data for which the Shapley values shall be plotted. plot_type: The type of plot to be created. Supported types are: "bar", "beeswarm", "force", "heatmap", "scatter". **kwargs: Additional keyword arguments to be passed to the plot function. @@ -378,7 +354,7 @@ def plot(self, plot_type: str, **kwargs: dict) -> None | plt.Axes: # Special case for scatter plot if plot_type == "scatter": - plot = self._plot_shap_scatter(show=show, **kwargs) + plot = self._plot_shap_scatter(df, show=show, **kwargs) if not show: return plot return None @@ -395,15 +371,18 @@ def plot(self, plot_type: str, **kwargs: dict) -> None | plt.Axes: f"{SUPPORTED_SHAP_PLOTS}." ) - plot = plot_func(self.explain(), show=show, **kwargs) + plot = plot_func(self.explain(df), show=show, **kwargs) if not show: return plot return None - def _plot_shap_scatter(self, show: bool = True, **kwargs: dict) -> None | plt.Axes: + def _plot_shap_scatter( + self, df: pd.DataFrame, /, show: bool = True, **kwargs: dict + ) -> None | plt.Axes: """Plot the Shapley values as scatter plot while leaving out string values. Args: + df: The data for which the Shapley values shall be plotted. show: Whether to call plt.show() after plotting or not. **kwargs: Additional keyword arguments to be passed to the plot function. @@ -422,7 +401,7 @@ def is_not_numeric_column(col): "explanation as it contains non-numeric values." 
) else: - plot = shap.plots.scatter(self.explain(), show=show, **kwargs) + plot = shap.plots.scatter(self.explain(df), show=show, **kwargs) else: # Type checking for mypy assert isinstance(self.background_data, pd.DataFrame) @@ -436,7 +415,7 @@ def is_not_numeric_column(col): "non-numeric values." ) plot = shap.plots.scatter( - self.explain()[:, number_enum], show=show, **kwargs + self.explain(df)[:, number_enum], show=show, **kwargs ) if not show: diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 0483cd208..f9a469dfc 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -7,6 +7,7 @@ from pytest import mark from baybe._optional.info import SHAP_INSTALLED +from baybe.campaign import Campaign from tests.conftest import run_iterations # File-wide parameterization settings @@ -61,7 +62,10 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): assert shap_insight.uses_shap_explainer == is_shap # Sanity check explanation - shap_explanation = shap_insight.explain() + df = campaign.measurements[[p.name for p in campaign.parameters]] + if use_comp_rep: + df = campaign.searchspace.transform(df) + shap_explanation = shap_insight.explain(df) assert isinstance(shap_explanation, shap.Explanation) except TypeError as e: if "The selected explainer class" in str(e): @@ -112,20 +116,23 @@ def test_invalid_explained_data(ongoing_campaign, explainer_cls, use_comp_rep): match="The provided data does not have the same amount of parameters as the " "shap explainer background.", ): - shap_insight._init_explanation(df) + shap_insight.explain(df) @mark.slow @mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) @mark.parametrize("plot_type", SUPPORTED_SHAP_PLOTS) -def test_plots(ongoing_campaign, use_comp_rep, plot_type): +def test_plots(ongoing_campaign: Campaign, use_comp_rep, plot_type): """Test the default SHAP plots.""" shap_insight = SHAPInsight.from_campaign( ongoing_campaign, 
use_comp_rep=use_comp_rep, ) + df = ongoing_campaign.measurements[[p.name for p in ongoing_campaign.parameters]] + if use_comp_rep: + df = ongoing_campaign.searchspace.transform(df) with mock.patch("matplotlib.pyplot.show"): - shap_insight.plot(plot_type) + shap_insight.plot(df, plot_type=plot_type) def test_updated_campaign_explanations(campaign, n_iterations, batch_size): @@ -138,11 +145,13 @@ def test_updated_campaign_explanations(campaign, n_iterations, batch_size): run_iterations(campaign, n_iterations=n_iterations, batch_size=batch_size) shap_insight = SHAPInsight.from_campaign(campaign) - explanation_1 = shap_insight.explain() + df = campaign.measurements[[p.name for p in campaign.parameters]] + explanation_1 = shap_insight.explain(df) run_iterations(campaign, n_iterations=n_iterations, batch_size=batch_size) shap_insight = SHAPInsight.from_campaign(campaign) - explanation_2 = shap_insight.explain() + df = campaign.measurements[[p.name for p in campaign.parameters]] + explanation_2 = shap_insight.explain(df) assert explanation_1 != explanation_2, "SHAP explanations should not be identical." 
From 09ee6ac21ccd7b2a3ebc39dda8439d2aed14f192 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 7 Jan 2025 14:23:51 +0100 Subject: [PATCH 40/92] Extract explainer factory function --- baybe/insights/shap.py | 163 ++++++++++++++++++++++------------------- 1 file changed, 89 insertions(+), 74 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 299a388e2..42e7eeebb 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -18,7 +18,7 @@ from baybe.objectives.base import Objective from baybe.recommenders.pure.bayesian.base import BayesianRecommender from baybe.searchspace import SearchSpace -from baybe.surrogates.base import SurrogateProtocol +from baybe.surrogates.base import Surrogate, SurrogateProtocol from baybe.utils.dataframe import to_tensor @@ -76,6 +76,87 @@ def _has_required_init_parameters(cls): } +def is_shap_explainer(explainer_cls: type[shap.Explainer]) -> bool: + """Whether the explainer is a SHAP explainer or not (e.g. MAPLE, LIME).""" + return not explainer_cls.__module__.startswith("shap.explainers.other.") + + +def _make_explainer( + surrogate: Surrogate, + data: pd.DataFrame, + explainer_cls: type[shap.Explainer] = shap.KernelExplainer, + use_comp_rep: bool = False, + **kwargs, +) -> shap.Explainer: + """Create a SHAP explainer. + + Args: + surrogate: The surrogate to be explained. + data: The background data set. + explainer_cls: The SHAP explainer class that is used to generate the + explanation. + use_comp_rep: Whether to analyze the model in computational representation + (experimental representation otherwise). + **kwargs: Additional keyword arguments to be passed to the explainer. + + Returns: + shap.Explainer: The created explainer object. + + Raises: + ValueError: If the provided background data set is empty. + TypeError: If the provided explainer class does not + support the campaign surrogate. 
+ """ + if data.empty: + raise ValueError("The provided background data set is empty.") + + import torch + + if use_comp_rep: + + def model(x: npt.ArrayLike) -> np.ndarray: + tensor = to_tensor(x) + with torch.no_grad(): + output = surrogate._posterior_comp(tensor).mean + return output.numpy() + + else: + + def model(x: npt.ArrayLike) -> np.ndarray: + df = pd.DataFrame(x, columns=data.columns) + with torch.no_grad(): + output = surrogate.posterior(df).mean + return output.numpy() + + # Handle special settings + if "Lime" in explainer_cls.__name__: + # Lime default mode is otherwise set to 'classification' + kwargs["mode"] = "regression" + + try: + shap_explainer = explainer_cls(model, data, **kwargs) + + # Explain first two data points to ensure that the explainer is working + if is_shap_explainer(explainer_cls): + shap_explainer(data.iloc[0:1]) + except shap.utils._exceptions.InvalidModelError: + raise TypeError( + f"The selected explainer class {explainer_cls} does not support the " + f"provided surrogate model." + ) + except TypeError as e: + if "not supported for the input types" in str(e) and not use_comp_rep: + raise NotImplementedError( + f"The selected explainer class {explainer_cls} does not support " + f"the experimental representation. Switch to computational " + f"representation or use a different explainer (e.g. the default " + f"shap.KernelExplainer)." + ) + else: + raise e + return shap_explainer + + @define class SHAPInsight: """Class for SHAP-based feature importance insights. 
@@ -105,7 +186,12 @@ class SHAPInsight: _explainer: shap.Explainer | None = field( default=Factory( - lambda self: self._init_explainer(self.background_data, self.explainer_cls), + lambda self: _make_explainer( + self.surrogate, + self.background_data, + self.explainer_cls, + self.use_comp_rep, + ), takes_self=True, ), init=False, @@ -123,7 +209,7 @@ def _validate_use_comp_rep(self, _, value: bool) -> None: @property def uses_shap_explainer(self) -> bool: """Whether the explainer is a SHAP explainer or not (e.g. MAPLE, LIME).""" - return not self.explainer_cls.__module__.startswith("shap.explainers.other.") + return is_shap_explainer(self._explainer) @classmethod def from_campaign( @@ -203,77 +289,6 @@ def from_recommender( use_comp_rep=use_comp_rep, ) - def _init_explainer( - self, - background_data: pd.DataFrame, - explainer_cls: type[shap.Explainer] = shap.KernelExplainer, - **kwargs, - ) -> shap.Explainer: - """Create a SHAP explainer. - - Args: - background_data: The background data set. - explainer_cls: The SHAP explainer class that is used to generate the - explanation. - **kwargs: Additional keyword arguments to be passed to the explainer. - - Returns: - shap.Explainer: The created explainer object. - - Raises: - ValueError: If the provided background data set is empty. - TypeError: If the provided explainer class does not - support the campaign surrogate. 
- """ - if background_data.empty: - raise ValueError("The provided background data set is empty.") - - import torch - - if self.use_comp_rep: - - def model(x: npt.ArrayLike) -> np.ndarray: - tensor = to_tensor(x) - with torch.no_grad(): - output = self.surrogate._posterior_comp(tensor).mean - return output.numpy() - - else: - - def model(x: npt.ArrayLike) -> np.ndarray: - df = pd.DataFrame(x, columns=background_data.columns) - with torch.no_grad(): - output = self.surrogate.posterior(df).mean - return output.numpy() - - # Handle special settings - if "Lime" in explainer_cls.__name__: - # Lime default mode is otherwise set to 'classification' - kwargs["mode"] = "regression" - - try: - shap_explainer = explainer_cls(model, background_data, **kwargs) - - # Explain first two data points to ensure that the explainer is working - if self.uses_shap_explainer: - shap_explainer(self.background_data.iloc[0:1]) - except shap.utils._exceptions.InvalidModelError: - raise TypeError( - f"The selected explainer class {explainer_cls} does not support the " - f"provided surrogate model." - ) - except TypeError as e: - if "not supported for the input types" in str(e) and not self.use_comp_rep: - raise NotImplementedError( - f"The selected explainer class {explainer_cls} does not support " - f"the experimental representation. Switch to computational " - f"representation or use a different explainer (e.g. the default " - f"shap.KernelExplainer)." - ) - else: - raise e - return shap_explainer - def explain(self, df: pd.DataFrame, /) -> shap.Explanation: """Compute the Shapley values based on the chosen explainer and data set. 
From c8940c0b089e6c2f19a136f7c9297ddff30d0369 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 7 Jan 2025 14:28:59 +0100 Subject: [PATCH 41/92] Add from_surrogate constructor --- baybe/insights/shap.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 42e7eeebb..c3ee51e4d 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -211,6 +211,22 @@ def uses_shap_explainer(self) -> bool: """Whether the explainer is a SHAP explainer or not (e.g. MAPLE, LIME).""" return is_shap_explainer(self._explainer) + @classmethod + def from_surrogate( + cls, + surrogate: Surrogate, + data: pd.DataFrame, + explainer_cls: type[shap.Explainer] | str = "KernelExplainer", + use_comp_rep: bool = False, + ): + """Create a SHAP insight from a surrogate model.""" + return cls( + surrogate, + background_data=data, + explainer_cls=explainer_cls, + use_comp_rep=use_comp_rep, + ) + @classmethod def from_campaign( cls, @@ -243,9 +259,9 @@ def from_campaign( data = campaign.measurements[[p.name for p in campaign.parameters]].copy() background_data = campaign.searchspace.transform(data) if use_comp_rep else data - return cls( + return cls.from_surrogate( campaign.get_surrogate(), - background_data=background_data, + background_data, explainer_cls=explainer_cls, use_comp_rep=use_comp_rep, ) @@ -280,11 +296,9 @@ def from_recommender( searchspace, objective, measurements ) - return cls( + return cls.from_surrogate( surrogate_model, - background_data=searchspace.transform(measurements) - if use_comp_rep - else measurements, + searchspace.transform(measurements) if use_comp_rep else measurements, explainer_cls=explainer_cls, use_comp_rep=use_comp_rep, ) From 86c99d3fea3685b88063d23cc6e9b72f8b24a1f6 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 7 Jan 2025 15:16:45 +0100 Subject: [PATCH 42/92] Refactor class attributes Following the attrs mantra "what the class needs, not how it is 
instantiated" --- baybe/insights/shap.py | 80 ++++++++++--------------------------- tests/insights/test_shap.py | 2 +- 2 files changed, 22 insertions(+), 60 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index c3ee51e4d..3c5f11d9f 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -10,7 +10,7 @@ import numpy as np import numpy.typing as npt import pandas as pd -from attrs import Factory, define, field +from attrs import define, field from attrs.validators import instance_of from baybe import Campaign @@ -18,7 +18,7 @@ from baybe.objectives.base import Objective from baybe.recommenders.pure.bayesian.base import BayesianRecommender from baybe.searchspace import SearchSpace -from baybe.surrogates.base import Surrogate, SurrogateProtocol +from baybe.surrogates.base import Surrogate from baybe.utils.dataframe import to_tensor @@ -76,6 +76,11 @@ def _has_required_init_parameters(cls): } +def _convert_explainer_cls(x: type[shap.Explainer] | str) -> type[shap.Explainer]: + """Get an explainer class from an explainer class name (with class passthrough).""" + return ALL_EXPLAINERS[x] if isinstance(x, str) else x + + def is_shap_explainer(explainer_cls: type[shap.Explainer]) -> bool: """Whether the explainer is a SHAP explainer or not (e.g. 
MAPLE, LIME).""" return not explainer_cls.__module__.startswith("shap.explainers.other.") @@ -84,7 +89,7 @@ def is_shap_explainer(explainer_cls: type[shap.Explainer]) -> bool: def _make_explainer( surrogate: Surrogate, data: pd.DataFrame, - explainer_cls: type[shap.Explainer] = shap.KernelExplainer, + explainer_cls: type[shap.Explainer] | str = shap.KernelExplainer, use_comp_rep: bool = False, **kwargs, ) -> shap.Explainer: @@ -110,6 +115,8 @@ def _make_explainer( if data.empty: raise ValueError("The provided background data set is empty.") + explainer_cls = _convert_explainer_cls(explainer_cls) + import torch if use_comp_rep: @@ -164,52 +171,16 @@ class SHAPInsight: This also supports LIME and MAPLE explainers via ways provided by the shap module. """ - surrogate: SurrogateProtocol = field() - """The surrogate model that is supposed bo be analyzed.""" + explainer: shap.Explainer = field(validator=instance_of(shap.Explainer)) + """The explainer instance.""" background_data: pd.DataFrame = field(validator=instance_of(pd.DataFrame)) """The background data set used to build the explainer.""" - # FIXME[typing]: https://github.com/python/mypy/issues/10998 - explainer_cls: type[shap.Explainer] = field( # type: ignore[assignment] - default="KernelExplainer", - converter=lambda x: ALL_EXPLAINERS[x] if isinstance(x, str) else x, - ) - """The SHAP explainer class that is used to generate the explanation. - - Some non-SHAP explainers, like MAPLE and LIME, are also supported if they are - available via 'shap.explainers.other'. 
- """ - - use_comp_rep: bool = field(default=False, validator=instance_of(bool)) - """Flag for toggling in which representation the insight should be provided.""" - - _explainer: shap.Explainer | None = field( - default=Factory( - lambda self: _make_explainer( - self.surrogate, - self.background_data, - self.explainer_cls, - self.use_comp_rep, - ), - takes_self=True, - ), - init=False, - ) - """The explainer generated from the model and background data.""" - - @use_comp_rep.validator - def _validate_use_comp_rep(self, _, value: bool) -> None: - if not self.uses_shap_explainer and not value: - raise NotImplementedError( - "Experimental representation is not supported for non-Kernel SHAP " - "explainer." - ) - @property def uses_shap_explainer(self) -> bool: """Whether the explainer is a SHAP explainer or not (e.g. MAPLE, LIME).""" - return is_shap_explainer(self._explainer) + return is_shap_explainer(type(self.explainer)) @classmethod def from_surrogate( @@ -220,12 +191,8 @@ def from_surrogate( use_comp_rep: bool = False, ): """Create a SHAP insight from a surrogate model.""" - return cls( - surrogate, - background_data=data, - explainer_cls=explainer_cls, - use_comp_rep=use_comp_rep, - ) + explainer = _make_explainer(surrogate, data, explainer_cls, use_comp_rep) + return cls(explainer, data) @classmethod def from_campaign( @@ -316,36 +283,31 @@ def explain(self, df: pd.DataFrame, /) -> shap.Explanation: ValueError: If the provided data set does not have the same amount of parameters as the SHAP explainer background """ - if df is None: - df = self.background_data - elif not self.background_data.shape[1] == df.shape[1]: + if not self.background_data.shape[1] == df.shape[1]: raise ValueError( "The provided data does not have the same amount of " "parameters as the shap explainer background." 
) - # Type checking for mypy - assert self._explainer is not None - if not self.uses_shap_explainer: # Return attributions for non-SHAP explainers - if self._explainer.__module__.endswith("maple"): + if self.explainer.__module__.endswith("maple"): # Additional argument for maple to increase comparability to SHAP - attributions = self._explainer.attributions(df, multiply_by_input=True)[ + attributions = self.explainer.attributions(df, multiply_by_input=True)[ 0 ] else: - attributions = self._explainer.attributions(df)[0] + attributions = self.explainer.attributions(df)[0] explanations = shap.Explanation( values=attributions, - base_values=self._explainer.model(self.background_data).mean(), + base_values=self.explainer.model(self.background_data).mean(), data=df, feature_names=df.columns.values, ) return explanations else: - explanations = self._explainer(df) + explanations = self.explainer(df) # Reduce dimensionality of explanations to 2D in case # a 3D explanation is returned. This is the case for diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index f9a469dfc..e62a830de 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -56,7 +56,7 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): ) assert isinstance(shap_insight, insights.SHAPInsight) assert isinstance( - shap_insight._explainer, + shap_insight.explainer, ALL_EXPLAINERS[explainer_cls], ) assert shap_insight.uses_shap_explainer == is_shap From 52db71bd3da975adab367a2d496aeadf97de28f3 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 7 Jan 2025 15:30:39 +0100 Subject: [PATCH 43/92] Drop duplicate input validation The validation needs to happen at the point where all construction paths cross, i.e. in the factory. 
--- baybe/insights/shap.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 3c5f11d9f..4eb8e64d5 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -213,16 +213,7 @@ def from_campaign( Returns: The SHAP insight object. - - Raises: - ValueError: If the campaign does not contain any measurements. """ - if campaign.measurements.empty: - raise ValueError( - f"The campaign does not contain any measurements. A '{cls.__name__}' " - f"assumes there is mandatory background data in the form of " - f"measurements as part of the campaign." - ) data = campaign.measurements[[p.name for p in campaign.parameters]].copy() background_data = campaign.searchspace.transform(data) if use_comp_rep else data From b03d3bf8620969d354a472bf54d4c13a78720b23 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 8 Jan 2025 14:34:44 +0100 Subject: [PATCH 44/92] Refactor plotting * Reduce code complexity * Fix signature * Return plot regardless of show flag * Raise error if plotting fails --- baybe/insights/shap.py | 103 +++++++++++++++-------------------------- 1 file changed, 38 insertions(+), 65 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 4eb8e64d5..954da9427 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -3,8 +3,8 @@ from __future__ import annotations import inspect -import numbers import warnings +from typing import Literal import matplotlib.pyplot as plt import numpy as np @@ -313,93 +313,66 @@ def explain(self, df: pd.DataFrame, /) -> shap.Explanation: ) def plot( - self, df: pd.DataFrame, /, plot_type: str, **kwargs: dict - ) -> None | plt.Axes: + self, + df: pd.DataFrame, + /, + plot_type: Literal["bar", "beeswarm", "force", "heatmap", "scatter"], + show: bool = True, + **kwargs: dict, + ) -> plt.Axes: """Plot the Shapley values using the provided plot type. Args: df: The data for which the Shapley values shall be plotted. 
- plot_type: The type of plot to be created. Supported types are: - "bar", "beeswarm", "force", "heatmap", "scatter". + plot_type: The type of plot to be created. + show: Boolean flag determining if the plot shall be rendered. **kwargs: Additional keyword arguments to be passed to the plot function. Returns: - None | plt.Axes: The plot object if 'show' is set to False. + The plot object. Raises: ValueError: If the provided plot type is not supported. """ - # Extract the 'show' argument from the kwargs - show = kwargs.pop("show", True) - - plot = None - - # Special case for scatter plot if plot_type == "scatter": - plot = self._plot_shap_scatter(df, show=show, **kwargs) - if not show: - return plot - return None - - # Cases for all other plots - plot_func = getattr(shap.plots, plot_type, None) - if ( - (plot_type not in SUPPORTED_SHAP_PLOTS) - or (plot_func is None) - or (not callable(plot_func)) - ): + return self._plot_shap_scatter(df, show=show, **kwargs) + + if plot_type not in SUPPORTED_SHAP_PLOTS: raise ValueError( - f"Invalid plot type: '{plot_type}'. Available options: " - f"{SUPPORTED_SHAP_PLOTS}." + f"Invalid plot type: '{plot_type}'. " + f"Available options: {SUPPORTED_SHAP_PLOTS}." ) + plot_func = getattr(shap.plots, plot_type) - plot = plot_func(self.explain(df), show=show, **kwargs) - if not show: - return plot - return None + return plot_func(self.explain(df), show=show, **kwargs) def _plot_shap_scatter( self, df: pd.DataFrame, /, show: bool = True, **kwargs: dict - ) -> None | plt.Axes: - """Plot the Shapley values as scatter plot while leaving out string values. + ) -> plt.Axes: + """Plot the Shapley values as scatter plot while leaving out non-numeric values. Args: df: The data for which the Shapley values shall be plotted. - show: Whether to call plt.show() after plotting or not. + show: Boolean flag determining if the plot shall be rendered. **kwargs: Additional keyword arguments to be passed to the plot function. 
Returns: - None | plt.Axes: The plot object if 'show' is set to False. - """ - plot = None - - def is_not_numeric_column(col): - return np.array([not isinstance(v, numbers.Number) for v in col]).any() + The plot object. - if np.ndim(self.background_data) == 1: - if is_not_numeric_column(self.background_data): - warnings.warn( - "Cannot plot scatter plot for the provided " - "explanation as it contains non-numeric values." - ) - else: - plot = shap.plots.scatter(self.explain(df), show=show, **kwargs) - else: - # Type checking for mypy - assert isinstance(self.background_data, pd.DataFrame) - - mask = self.background_data.iloc[0].apply(lambda x: not isinstance(x, str)) - number_enum = np.where(mask)[0].tolist() - - if len(number_enum) < len(self.background_data.iloc[0]): - warnings.warn( - "Cannot plot SHAP scatter plot for all parameters as some contain " - "non-numeric values." - ) - plot = shap.plots.scatter( - self.explain(df)[:, number_enum], show=show, **kwargs + Raises: + ValueError: If no plot can be created because of non-numeric data. + """ + df_numeric = df.select_dtypes("number") + numeric_idx = df.columns.get_indexer(df_numeric.columns) + if df_numeric.empty: + raise ValueError( + "No SHAP scatter plot can be created since all features contain " + "non-numeric values." 
) - - if not show: - return plot - return None + if non_numeric_cols := set(df.columns) - set(df_numeric.columns): + warnings.warn( + f"The following features are excluded from the SHAP scatter plot " + f"because their contain non-numeric values: {non_numeric_cols}", + UserWarning, + ) + return shap.plots.scatter(self.explain(df)[:, numeric_idx], show=show, **kwargs) From d4aeaf4c1282731fabcf5c558a39cacba0fd29c9 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 8 Jan 2025 14:38:19 +0100 Subject: [PATCH 45/92] Define default explainer class --- baybe/insights/shap.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 954da9427..d7c1b397f 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -21,6 +21,8 @@ from baybe.surrogates.base import Surrogate from baybe.utils.dataframe import to_tensor +_DEFAULT_EXPLAINER_CLS = "KernelExplainer" + def _get_explainer_maps() -> ( tuple[dict[str, type[shap.Explainer]], dict[str, type[shap.Explainer]]] @@ -89,7 +91,7 @@ def is_shap_explainer(explainer_cls: type[shap.Explainer]) -> bool: def _make_explainer( surrogate: Surrogate, data: pd.DataFrame, - explainer_cls: type[shap.Explainer] | str = shap.KernelExplainer, + explainer_cls: type[shap.Explainer] | str = _DEFAULT_EXPLAINER_CLS, use_comp_rep: bool = False, **kwargs, ) -> shap.Explainer: @@ -187,7 +189,7 @@ def from_surrogate( cls, surrogate: Surrogate, data: pd.DataFrame, - explainer_cls: type[shap.Explainer] | str = "KernelExplainer", + explainer_cls: type[shap.Explainer] | str = _DEFAULT_EXPLAINER_CLS, use_comp_rep: bool = False, ): """Create a SHAP insight from a surrogate model.""" @@ -198,7 +200,7 @@ def from_surrogate( def from_campaign( cls, campaign: Campaign, - explainer_cls: type[shap.Explainer] | str = "KernelExplainer", + explainer_cls: type[shap.Explainer] | str = _DEFAULT_EXPLAINER_CLS, use_comp_rep: bool = False, ) -> SHAPInsight: """Create a SHAP insight from a 
campaign. From 331746d4cb6140e08111c7c6e1c1231e6d50879d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 8 Jan 2025 14:48:57 +0100 Subject: [PATCH 46/92] Make data to be explained optional --- baybe/insights/shap.py | 23 ++++++++++++++++++----- tests/insights/test_shap.py | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index d7c1b397f..b50ff7c23 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -263,11 +263,12 @@ def from_recommender( use_comp_rep=use_comp_rep, ) - def explain(self, df: pd.DataFrame, /) -> shap.Explanation: + def explain(self, df: pd.DataFrame | None = None, /) -> shap.Explanation: """Compute the Shapley values based on the chosen explainer and data set. Args: df: The data set for which the Shapley values should be computed. + By default, the background data of the explainer is used. Returns: shap.Explanation: The computed Shapley values. @@ -276,6 +277,9 @@ def explain(self, df: pd.DataFrame, /) -> shap.Explanation: ValueError: If the provided data set does not have the same amount of parameters as the SHAP explainer background """ + if df is None: + df = self.background_data + if not self.background_data.shape[1] == df.shape[1]: raise ValueError( "The provided data does not have the same amount of " @@ -316,17 +320,19 @@ def explain(self, df: pd.DataFrame, /) -> shap.Explanation: def plot( self, - df: pd.DataFrame, - /, plot_type: Literal["bar", "beeswarm", "force", "heatmap", "scatter"], + df: pd.DataFrame | None = None, + /, + *, show: bool = True, **kwargs: dict, ) -> plt.Axes: """Plot the Shapley values using the provided plot type. Args: - df: The data for which the Shapley values shall be plotted. plot_type: The type of plot to be created. + df: The data for which the Shapley values shall be plotted. + By default, the background data of the explainer is used. show: Boolean flag determining if the plot shall be rendered. 
**kwargs: Additional keyword arguments to be passed to the plot function. @@ -336,6 +342,9 @@ def plot( Raises: ValueError: If the provided plot type is not supported. """ + if df is None: + df = self.background_data + if plot_type == "scatter": return self._plot_shap_scatter(df, show=show, **kwargs) @@ -349,12 +358,13 @@ def plot( return plot_func(self.explain(df), show=show, **kwargs) def _plot_shap_scatter( - self, df: pd.DataFrame, /, show: bool = True, **kwargs: dict + self, df: pd.DataFrame | None = None, /, *, show: bool = True, **kwargs: dict ) -> plt.Axes: """Plot the Shapley values as scatter plot while leaving out non-numeric values. Args: df: The data for which the Shapley values shall be plotted. + By default, the background data of the explainer is used. show: Boolean flag determining if the plot shall be rendered. **kwargs: Additional keyword arguments to be passed to the plot function. @@ -364,6 +374,9 @@ def _plot_shap_scatter( Raises: ValueError: If no plot can be created because of non-numeric data. 
""" + if df is None: + df = self.background_data + df_numeric = df.select_dtypes("number") numeric_idx = df.columns.get_indexer(df_numeric.columns) if df_numeric.empty: diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index e62a830de..ec8dcedb4 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -132,7 +132,7 @@ def test_plots(ongoing_campaign: Campaign, use_comp_rep, plot_type): if use_comp_rep: df = ongoing_campaign.searchspace.transform(df) with mock.patch("matplotlib.pyplot.show"): - shap_insight.plot(df, plot_type=plot_type) + shap_insight.plot(plot_type, df) def test_updated_campaign_explanations(campaign, n_iterations, batch_size): From 2dc0e941bd72c37a1b17f0751b433a8beaf24553 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 8 Jan 2025 14:53:53 +0100 Subject: [PATCH 47/92] Drop converter utility --- baybe/insights/shap.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index b50ff7c23..c58822ebd 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -78,11 +78,6 @@ def _has_required_init_parameters(cls): } -def _convert_explainer_cls(x: type[shap.Explainer] | str) -> type[shap.Explainer]: - """Get an explainer class from an explainer class name (with class passthrough).""" - return ALL_EXPLAINERS[x] if isinstance(x, str) else x - - def is_shap_explainer(explainer_cls: type[shap.Explainer]) -> bool: """Whether the explainer is a SHAP explainer or not (e.g. 
MAPLE, LIME).""" return not explainer_cls.__module__.startswith("shap.explainers.other.") @@ -117,7 +112,8 @@ def _make_explainer( if data.empty: raise ValueError("The provided background data set is empty.") - explainer_cls = _convert_explainer_cls(explainer_cls) + if isinstance(explainer_cls, str): + explainer_cls = ALL_EXPLAINERS[explainer_cls] import torch From 31d95f4167e652c0671b8a7516592eda006b3915 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 09:36:49 +0100 Subject: [PATCH 48/92] Refactor explainer sets --- baybe/insights/shap.py | 84 +++++++++++-------------------------- tests/insights/test_shap.py | 33 +++++---------- tests/insights/test_tmp.py | 22 ++++++++++ 3 files changed, 57 insertions(+), 82 deletions(-) create mode 100644 tests/insights/test_tmp.py diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index c58822ebd..0686f2af6 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -2,7 +2,6 @@ from __future__ import annotations -import inspect import warnings from typing import Literal @@ -22,65 +21,30 @@ from baybe.utils.dataframe import to_tensor _DEFAULT_EXPLAINER_CLS = "KernelExplainer" +SHAP_EXPLAINERS = { + "AdditiveExplainer", + "ExactExplainer", + "KernelExplainer", + "PartitionExplainer", + "PermutationExplainer", +} +NON_SHAP_EXPLAINERS = {"LimeTabular", "Maple"} +EXPLAINERS = {*SHAP_EXPLAINERS, *NON_SHAP_EXPLAINERS} +SHAP_PLOTS = {"bar", "beeswarm", "heatmap", "scatter"} -def _get_explainer_maps() -> ( - tuple[dict[str, type[shap.Explainer]], dict[str, type[shap.Explainer]]] -): - """Get maps for SHAP and non-SHAP explainers. - - Returns: - The maps for SHAP and non-SHAP explainers. 
- """ - EXCLUDED_EXPLAINER_KEYWORDS = [ - "Tree", - "GPU", - "Gradient", - "Sampling", - "Deep", - "Linear", - ] - - def _has_required_init_parameters(cls): - """Check if non-shap initializer has required standard parameters.""" - REQUIRED_PARAMETERS = ["self", "model", "data"] - - init_signature = inspect.signature(cls.__init__) - parameters = list(init_signature.parameters.keys()) - - return parameters[:3] == REQUIRED_PARAMETERS - - shap_explainers = { - cls_name: getattr(shap.explainers, cls_name) - for cls_name in shap.explainers.__all__ - if all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) - } - - non_shap_explainers = { - cls_name: explainer - for cls_name in shap.explainers.other.__all__ - if _has_required_init_parameters( - explainer := getattr(shap.explainers.other, cls_name) - ) - and all(x not in cls_name for x in EXCLUDED_EXPLAINER_KEYWORDS) - } - - return shap_explainers, non_shap_explainers - - -SHAP_EXPLAINERS, NON_SHAP_EXPLAINERS = _get_explainer_maps() -ALL_EXPLAINERS = SHAP_EXPLAINERS | NON_SHAP_EXPLAINERS -SUPPORTED_SHAP_PLOTS = { - "bar", - "beeswarm", - "heatmap", - "scatter", -} +def _get_explainer_cls(name: str) -> type[shap.Explainer]: + """Retrieve the explainer class reference by name.""" + if name in SHAP_EXPLAINERS: + return getattr(shap.explainers, name) + if name in NON_SHAP_EXPLAINERS: + return getattr(shap.explainers.other, name) + raise ValueError(f"Unknown SHAP explainer class '{name}'.") -def is_shap_explainer(explainer_cls: type[shap.Explainer]) -> bool: +def is_shap_explainer(explainer: shap.Explainer) -> bool: """Whether the explainer is a SHAP explainer or not (e.g. 
MAPLE, LIME).""" - return not explainer_cls.__module__.startswith("shap.explainers.other.") + return type(explainer).__name__ in SHAP_EXPLAINERS def _make_explainer( @@ -113,7 +77,7 @@ def _make_explainer( raise ValueError("The provided background data set is empty.") if isinstance(explainer_cls, str): - explainer_cls = ALL_EXPLAINERS[explainer_cls] + explainer_cls = _get_explainer_cls(explainer_cls) import torch @@ -142,7 +106,7 @@ def model(x: npt.ArrayLike) -> np.ndarray: shap_explainer = explainer_cls(model, data, **kwargs) # Explain first two data points to ensure that the explainer is working - if is_shap_explainer(explainer_cls): + if is_shap_explainer(shap_explainer): shap_explainer(data.iloc[0:1]) except shap.utils._exceptions.InvalidModelError: raise TypeError( @@ -178,7 +142,7 @@ class SHAPInsight: @property def uses_shap_explainer(self) -> bool: """Whether the explainer is a SHAP explainer or not (e.g. MAPLE, LIME).""" - return is_shap_explainer(type(self.explainer)) + return is_shap_explainer(self.explainer) @classmethod def from_surrogate( @@ -344,10 +308,10 @@ def plot( if plot_type == "scatter": return self._plot_shap_scatter(df, show=show, **kwargs) - if plot_type not in SUPPORTED_SHAP_PLOTS: + if plot_type not in SHAP_PLOTS: raise ValueError( f"Invalid plot type: '{plot_type}'. " - f"Available options: {SUPPORTED_SHAP_PLOTS}." + f"Available options: {SHAP_PLOTS}." 
) plot_func = getattr(shap.plots, plot_type) diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index ec8dcedb4..4a19c1c73 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -6,8 +6,17 @@ import pytest from pytest import mark +from baybe import insights from baybe._optional.info import SHAP_INSTALLED +from baybe._optional.insights import shap from baybe.campaign import Campaign +from baybe.insights.shap import ( + NON_SHAP_EXPLAINERS, + SHAP_EXPLAINERS, + SHAP_PLOTS, + SHAPInsight, + _get_explainer_cls, +) from tests.conftest import run_iterations # File-wide parameterization settings @@ -27,23 +36,6 @@ ] -if SHAP_INSTALLED: - from baybe import insights - from baybe._optional.insights import shap - from baybe.insights.shap import ( - ALL_EXPLAINERS, - NON_SHAP_EXPLAINERS, - SHAP_EXPLAINERS, - SUPPORTED_SHAP_PLOTS, - SHAPInsight, - ) -else: - ALL_EXPLAINERS = [] - NON_SHAP_EXPLAINERS = [] - SHAP_EXPLAINERS = [] - SUPPORTED_SHAP_PLOTS = [] - - def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): """Helper function for general SHAP explainer tests.""" # run_iterations(campaign, n_iterations=2, batch_size=5) @@ -55,10 +47,7 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): use_comp_rep=use_comp_rep, ) assert isinstance(shap_insight, insights.SHAPInsight) - assert isinstance( - shap_insight.explainer, - ALL_EXPLAINERS[explainer_cls], - ) + assert isinstance(shap_insight.explainer, _get_explainer_cls(explainer_cls)) assert shap_insight.uses_shap_explainer == is_shap # Sanity check explanation @@ -121,7 +110,7 @@ def test_invalid_explained_data(ongoing_campaign, explainer_cls, use_comp_rep): @mark.slow @mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"]) -@mark.parametrize("plot_type", SUPPORTED_SHAP_PLOTS) +@mark.parametrize("plot_type", SHAP_PLOTS) def test_plots(ongoing_campaign: Campaign, use_comp_rep, plot_type): """Test the default SHAP plots.""" shap_insight 
= SHAPInsight.from_campaign( diff --git a/tests/insights/test_tmp.py b/tests/insights/test_tmp.py new file mode 100644 index 000000000..4687c99a1 --- /dev/null +++ b/tests/insights/test_tmp.py @@ -0,0 +1,22 @@ +"""Temporary test file.""" + +import inspect + +import pytest +import shap + +from baybe.insights.shap import NON_SHAP_EXPLAINERS, _get_explainer_cls + + +def _has_required_init_parameters(cls: type[shap.Explainer]) -> bool: + """Check if non-shap initializer has required standard parameters.""" + REQUIRED_PARAMETERS = ["self", "model", "data"] + init_signature = inspect.signature(cls.__init__) + parameters = list(init_signature.parameters.keys()) + return parameters[:3] == REQUIRED_PARAMETERS + + +@pytest.mark.parametrize("explainer_name", NON_SHAP_EXPLAINERS) +def test_non_shap_signature(explainer_name): + """Non-SHAP explainers must have the required signature.""" + assert _has_required_init_parameters(_get_explainer_cls(explainer_name)) From 84b5f1808a4c6eb9ceb5037c4178396373ae9e93 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 09:43:42 +0100 Subject: [PATCH 49/92] Add column permutation test --- tests/insights/test_tmp.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/insights/test_tmp.py b/tests/insights/test_tmp.py index 4687c99a1..bd9d0d93f 100644 --- a/tests/insights/test_tmp.py +++ b/tests/insights/test_tmp.py @@ -2,10 +2,13 @@ import inspect +import numpy as np +import pandas as pd import pytest import shap +from shap.explainers import KernelExplainer -from baybe.insights.shap import NON_SHAP_EXPLAINERS, _get_explainer_cls +from baybe.insights.shap import NON_SHAP_EXPLAINERS, SHAPInsight, _get_explainer_cls def _has_required_init_parameters(cls: type[shap.Explainer]) -> bool: @@ -20,3 +23,22 @@ def _has_required_init_parameters(cls: type[shap.Explainer]) -> bool: def test_non_shap_signature(explainer_name): """Non-SHAP explainers must have the required signature.""" assert 
_has_required_init_parameters(_get_explainer_cls(explainer_name)) + + +def test_column_permutation(): + """Explaining data with permuted columns gives permuted explanations.""" + N = 10 + + # Create insights object and test data + background_data = pd.DataFrame(np.random.random((N, 3)), columns=["x", "y", "z"]) + explainer = KernelExplainer(lambda x: x, background_data) + insights = SHAPInsight(explainer, background_data) + df = pd.DataFrame(np.random.random((N, 3)), columns=["x", "y", "z"]) + + # Regular column order + ex1 = insights.explain(df) + + # Permuted column order + ex2 = insights.explain(df[["z", "x", "y"]])[:, [1, 2, 0]] + + assert np.array_equal(ex1.values, ex2.values) From 151207b8db54bc6e59cf9da452631492390daee1 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 09:53:08 +0100 Subject: [PATCH 50/92] Make test pass by permuting columns --- baybe/insights/shap.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 0686f2af6..427c8e72f 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -239,32 +239,37 @@ def explain(self, df: pd.DataFrame | None = None, /) -> shap.Explanation: """ if df is None: df = self.background_data - - if not self.background_data.shape[1] == df.shape[1]: + elif set(self.background_data.columns) != set(df.columns): raise ValueError( - "The provided data does not have the same amount of " - "parameters as the shap explainer background." + "The provided dataframe must have the same column names as used by " + "the explainer object." 
) + # Align columns with background data + df_aligned = df[self.background_data.columns] + if not self.uses_shap_explainer: # Return attributions for non-SHAP explainers if self.explainer.__module__.endswith("maple"): # Additional argument for maple to increase comparability to SHAP - attributions = self.explainer.attributions(df, multiply_by_input=True)[ - 0 - ] + attributions = self.explainer.attributions( + df_aligned, multiply_by_input=True + )[0] else: - attributions = self.explainer.attributions(df)[0] + attributions = self.explainer.attributions(df_aligned)[0] explanations = shap.Explanation( values=attributions, base_values=self.explainer.model(self.background_data).mean(), - data=df, - feature_names=df.columns.values, + data=df_aligned, + feature_names=df_aligned.columns.values, ) - return explanations else: - explanations = self.explainer(df) + explanations = self.explainer(df_aligned) + + # Permute explanation object according to input column order + idx = self.background_data.columns.get_indexer(df.columns) + explanations = explanations[:, idx] # Reduce dimensionality of explanations to 2D in case # a 3D explanation is returned. This is the case for From 8750a0f282852e2c2fa4641c6b11546a9acdae32 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 13:40:33 +0100 Subject: [PATCH 51/92] Rework docstrings --- baybe/insights/shap.py | 107 ++++++++++++++++++----------------------- 1 file changed, 47 insertions(+), 60 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 427c8e72f..6f89888dd 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -43,35 +43,31 @@ def _get_explainer_cls(name: str) -> type[shap.Explainer]: def is_shap_explainer(explainer: shap.Explainer) -> bool: - """Whether the explainer is a SHAP explainer or not (e.g. MAPLE, LIME).""" + """Indicate if the given explainer is a SHAP explainer or not (e.g. 
MAPLE, LIME).""" return type(explainer).__name__ in SHAP_EXPLAINERS -def _make_explainer( +def make_explainer_for_surrogate( surrogate: Surrogate, data: pd.DataFrame, explainer_cls: type[shap.Explainer] | str = _DEFAULT_EXPLAINER_CLS, use_comp_rep: bool = False, - **kwargs, ) -> shap.Explainer: - """Create a SHAP explainer. + """Create a SHAP explainer for a given surrogate model. Args: - surrogate: The surrogate to be explained. + surrogate: The surrogate model to be explained. data: The background data set. - explainer_cls: The SHAP explainer class that is used to generate the - explanation. - use_comp_rep: Whether to analyze the model in computational representation - (experimental representation otherwise). - **kwargs: Additional keyword arguments to be passed to the explainer. + explainer_cls: The SHAP explainer class for generating the explanation. + use_comp_rep: Boolean flag specifying whether to explain the model's + experimental or computational representation. Returns: - shap.Explainer: The created explainer object. + The created explainer object. Raises: ValueError: If the provided background data set is empty. - TypeError: If the provided explainer class does not - support the campaign surrogate. + TypeError: If the provided explainer class is incompatible with the surrogate. 
""" if data.empty: raise ValueError("The provided background data set is empty.") @@ -97,10 +93,8 @@ def model(x: npt.ArrayLike) -> np.ndarray: output = surrogate.posterior(df).mean return output.numpy() - # Handle special settings - if "Lime" in explainer_cls.__name__: - # Lime default mode is otherwise set to 'classification' - kwargs["mode"] = "regression" + # Handle special settings: Lime default mode is otherwise set to "classification" + kwargs = {"mode": "regression"} if explainer_cls.__name__ == "LimeTabular" else {} try: shap_explainer = explainer_cls(model, data, **kwargs) @@ -130,18 +124,18 @@ def model(x: npt.ArrayLike) -> np.ndarray: class SHAPInsight: """Class for SHAP-based feature importance insights. - This also supports LIME and MAPLE explainers via ways provided by the shap module. + Also supports LIME and MAPLE explainers via the ``shap`` module. """ explainer: shap.Explainer = field(validator=instance_of(shap.Explainer)) """The explainer instance.""" background_data: pd.DataFrame = field(validator=instance_of(pd.DataFrame)) - """The background data set used to build the explainer.""" + """The background data set used by the explainer.""" @property def uses_shap_explainer(self) -> bool: - """Whether the explainer is a SHAP explainer or not (e.g. MAPLE, LIME).""" + """Indicates if a SHAP explainer is used or not (e.g. MAPLE, LIME).""" return is_shap_explainer(self.explainer) @classmethod @@ -152,8 +146,13 @@ def from_surrogate( explainer_cls: type[shap.Explainer] | str = _DEFAULT_EXPLAINER_CLS, use_comp_rep: bool = False, ): - """Create a SHAP insight from a surrogate model.""" - explainer = _make_explainer(surrogate, data, explainer_cls, use_comp_rep) + """Create a SHAP insight from a campaign. + + For details, see :func:`make_explainer_for_surrogate`. 
+ """ + explainer = make_explainer_for_surrogate( + surrogate, data, explainer_cls, use_comp_rep + ) return cls(explainer, data) @classmethod @@ -165,13 +164,12 @@ def from_campaign( ) -> SHAPInsight: """Create a SHAP insight from a campaign. + Uses the measurements of the campaign as background data. + Args: - campaign: The campaign which holds the recommender and model. - explainer_cls: The SHAP explainer class that is used to generate the - explanation. - use_comp_rep: - Whether to analyze the model in computational representation - (experimental representation otherwise). + campaign: A campaign holding a recommender using a surrogate model. + explainer_cls: See :func:`make_explainer_for_surrogate. + use_comp_rep: See :func:`make_explainer_for_surrogate. Returns: The SHAP insight object. @@ -198,16 +196,16 @@ def from_recommender( ) -> SHAPInsight: """Create a SHAP insight from a recommender. + Uses the provided measurements to train the surrogate and as background data for + the explainer. + Args: - recommender: The model-based recommender. + recommender: A recommender using a surrogate model. searchspace: The searchspace for the recommender. objective: The objective for the recommender. - measurements: The background data set for Explainer. - This is used the measurement data set for the recommender. - explainer_cls: The explainer class. - use_comp_rep: - Whether to analyze the model in computational representation - (experimental representation otherwise). + measurements: The measurements for training the surrogate and the explainer. + explainer_cls: See :func:`make_explainer_for_surrogate. + use_comp_rep: See :func:`make_explainer_for_surrogate. Returns: The SHAP insight object. @@ -224,18 +222,18 @@ def from_recommender( ) def explain(self, df: pd.DataFrame | None = None, /) -> shap.Explanation: - """Compute the Shapley values based on the chosen explainer and data set. + """Compute a Shapley explanation for a given data set. 
Args: - df: The data set for which the Shapley values should be computed. - By default, the background data of the explainer is used. + df: The dataframe for which the Shapley values are to be computed. + By default, the background data set of the explainer is used. Returns: - shap.Explanation: The computed Shapley values. + The computed Shapley explanation. Raises: - ValueError: If the provided data set does not have the same amount of - parameters as the SHAP explainer background + ValueError: If the columns of the given dataframe cannot be aligned with the + columns of the explainer background dataframe. """ if df is None: df = self.background_data @@ -279,8 +277,8 @@ def explain(self, df: pd.DataFrame | None = None, /) -> shap.Explanation: if len(explanations.shape) == 3: return explanations[:, :, 0] raise RuntimeError( - f"The explanation obtained for {self.__class__.__name__} has an unexpected " - f"invalid dimensionality of {len(explanations.shape)}." + f"The explanation obtained for '{self.__class__.__name__}' has an " + f"unexpected dimensionality of {len(explanations.shape)}." ) def plot( @@ -296,10 +294,9 @@ def plot( Args: plot_type: The type of plot to be created. - df: The data for which the Shapley values shall be plotted. - By default, the background data of the explainer is used. - show: Boolean flag determining if the plot shall be rendered. - **kwargs: Additional keyword arguments to be passed to the plot function. + df: See :meth:`explain`. + show: Boolean flag determining if the plot is to be rendered. + **kwargs: Additional keyword arguments passed to the plot function. Returns: The plot object. @@ -325,19 +322,9 @@ def plot( def _plot_shap_scatter( self, df: pd.DataFrame | None = None, /, *, show: bool = True, **kwargs: dict ) -> plt.Axes: - """Plot the Shapley values as scatter plot while leaving out non-numeric values. + """Plot the Shapley values as scatter plot, ignoring non-numeric features. 
- Args: - df: The data for which the Shapley values shall be plotted. - By default, the background data of the explainer is used. - show: Boolean flag determining if the plot shall be rendered. - **kwargs: Additional keyword arguments to be passed to the plot function. - - Returns: - The plot object. - - Raises: - ValueError: If no plot can be created because of non-numeric data. + For details, see :meth:`explain`. """ if df is None: df = self.background_data @@ -352,7 +339,7 @@ def _plot_shap_scatter( if non_numeric_cols := set(df.columns) - set(df_numeric.columns): warnings.warn( f"The following features are excluded from the SHAP scatter plot " - f"because their contain non-numeric values: {non_numeric_cols}", + f"because they contain non-numeric values: {non_numeric_cols}", UserWarning, ) return shap.plots.scatter(self.explain(df)[:, numeric_idx], show=show, **kwargs) From b748a5718e98a07e6627d5dd4a5143fbe9de02e8 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 13:43:48 +0100 Subject: [PATCH 52/92] Avoid unnecessary data copy --- baybe/insights/shap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 6f89888dd..928539dd3 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -174,7 +174,7 @@ def from_campaign( Returns: The SHAP insight object. 
""" - data = campaign.measurements[[p.name for p in campaign.parameters]].copy() + data = campaign.measurements[[p.name for p in campaign.parameters]] background_data = campaign.searchspace.transform(data) if use_comp_rep else data return cls.from_surrogate( From 6e81c9716b8a6ab6432bf8b93e8eb1c84660d3c1 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 14:32:00 +0100 Subject: [PATCH 53/92] Fix optional dependency handling --- baybe/_optional/info.py | 5 +++-- pytest.ini | 1 - tests/insights/test_shap.py | 8 ++++++-- tests/insights/test_tmp.py | 6 ++++++ 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/baybe/_optional/info.py b/baybe/_optional/info.py index 0381c3e36..fdedb38e6 100644 --- a/baybe/_optional/info.py +++ b/baybe/_optional/info.py @@ -25,12 +25,13 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404 # Individual packages with exclude_sys_path(os.getcwd()): FLAKE8_INSTALLED = find_spec("flake8") is not None + LIME_INSTALLED = find_spec("lime") is not None ONNX_INSTALLED = find_spec("onnxruntime") is not None POLARS_INSTALLED = find_spec("polars") is not None - SHAP_INSTALLED = find_spec("shap") is not None PRE_COMMIT_INSTALLED = find_spec("pre_commit") is not None PYDOCLINT_INSTALLED = find_spec("pydoclint") is not None RUFF_INSTALLED = find_spec("ruff") is not None + SHAP_INSTALLED = find_spec("shap") is not None SKFP_INSTALLED = find_spec("skfp") is not None # scikit-fingerprints STREAMLIT_INSTALLED = find_spec("streamlit") is not None XYZPY_INSTALLED = find_spec("xyzpy") is not None @@ -45,7 +46,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404 # Information on whether all required packages for certain functionality are available CHEM_INSTALLED = SKFP_INSTALLED -INSIGHTS_INSTALLED = SHAP_INSTALLED +INSIGHTS_INSTALLED = SHAP_INSTALLED and LIME_INSTALLED LINT_INSTALLED = all( ( FLAKE8_INSTALLED, diff --git a/pytest.ini b/pytest.ini index 2dc951c4d..c993cc465 100644 --- a/pytest.ini +++ b/pytest.ini @@ 
-10,7 +10,6 @@ addopts = --ignore=baybe/_optional --ignore=baybe/utils/chemistry.py --ignore=tests/simulate_telemetry.py - --ignore=baybe/insights testpaths = baybe tests \ No newline at end of file diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 4a19c1c73..86ceb3cdf 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -6,8 +6,13 @@ import pytest from pytest import mark +from baybe._optional.info import INSIGHTS_INSTALLED + +if not INSIGHTS_INSTALLED: + pytest.skip("Optional insights package not installed.", allow_module_level=True) + + from baybe import insights -from baybe._optional.info import SHAP_INSTALLED from baybe._optional.insights import shap from baybe.campaign import Campaign from baybe.insights.shap import ( @@ -21,7 +26,6 @@ # File-wide parameterization settings pytestmark = [ - mark.skipif(not SHAP_INSTALLED, reason="Optional shap package not installed."), mark.parametrize("n_grid_points", [5], ids=["g5"]), mark.parametrize("n_iterations", [2], ids=["i2"]), mark.parametrize("batch_size", [2], ids=["b2"]), diff --git a/tests/insights/test_tmp.py b/tests/insights/test_tmp.py index bd9d0d93f..ba9e09c5a 100644 --- a/tests/insights/test_tmp.py +++ b/tests/insights/test_tmp.py @@ -5,6 +5,12 @@ import numpy as np import pandas as pd import pytest + +from baybe._optional.info import INSIGHTS_INSTALLED + +if not INSIGHTS_INSTALLED: + pytest.skip("Optional insights package not installed.", allow_module_level=True) + import shap from shap.explainers import KernelExplainer From 4ab58988e8130676eec3b83c8e00b74b475a5d54 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 14:34:13 +0100 Subject: [PATCH 54/92] Use shap's optional dependency group --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index baa03143c..55f9b9e5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,8 +95,7 @@ dev = [ ] insights = [ - "shap>=0.46.0", - 
"lime>=0.2.0.1" + "shap[others]>=0.46.0", ] docs = [ From c819d97b89f38c19033fab00e3fcb8b42318ec4d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 14:39:11 +0100 Subject: [PATCH 55/92] Update lockfile --- .lockfiles/py310-dev.lock | 79 ++++++++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/.lockfiles/py310-dev.lock b/.lockfiles/py310-dev.lock index 5e4f963a5..003b461b2 100644 --- a/.lockfiles/py310-dev.lock +++ b/.lockfiles/py310-dev.lock @@ -8,7 +8,7 @@ anyio==4.4.0 # via # httpx # jupyter-server -appnope==0.1.4 ; platform_system == 'Darwin' +appnope==0.1.4 ; sys_platform == 'darwin' # via ipykernel argon2-cffi==23.1.0 # via jupyter-server @@ -137,7 +137,6 @@ docutils==0.21.2 # myst-parser # pybtex-docutils # sphinx - # sphinx-paramlinks # sphinxcontrib-bibtex e3fp==1.2.5 # via scikit-fingerprints @@ -232,6 +231,8 @@ idna==3.7 # httpx # jsonschema # requests +imageio==2.36.1 + # via scikit-image imagesize==1.4.1 # via sphinx importlib-metadata==7.1.0 @@ -240,7 +241,7 @@ importlib-metadata==7.1.0 # opentelemetry-api iniconfig==2.0.0 # via pytest -intel-openmp==2021.4.0 ; platform_system == 'Windows' +intel-openmp==2021.4.0 ; sys_platform == 'win32' # via mkl interface-meta==1.3.0 # via formulaic @@ -346,18 +347,20 @@ kiwisolver==1.4.5 # via matplotlib latexcodec==3.0.0 # via pybtex +lazy-loader==0.4 + # via scikit-image license-expression==30.3.0 # via cyclonedx-python-lib lifelines==0.29.0 # via ngboost +lime==0.2.0.1 + # via shap linear-operator==0.5.2 # via # botorch # gpytorch llvmlite==0.43.0 # via numba -llvmlite==0.43.0 - # via numba locket==1.0.0 # via partd markdown-it-py==3.0.0 @@ -374,6 +377,7 @@ matplotlib==3.9.1 # via # baybe (pyproject.toml) # lifelines + # lime # seaborn # types-seaborn matplotlib-inline==0.1.7 @@ -390,7 +394,7 @@ mdurl==0.1.2 # via markdown-it-py mistune==3.0.2 # via nbconvert -mkl==2021.4.0 ; platform_system == 'Windows' +mkl==2021.4.0 ; sys_platform == 'win32' # via 
torch mmh3==5.0.1 # via e3fp @@ -428,6 +432,7 @@ nest-asyncio==1.6.0 networkx==3.3 # via # mordredcommunity + # scikit-image # torch ngboost==0.5.1 # via baybe (pyproject.toml) @@ -440,9 +445,9 @@ notebook-shim==0.2.4 # jupyterlab # notebook numba==0.60.0 - # via shap -numba==0.60.0 - # via scikit-fingerprints + # via + # scikit-fingerprints + # shap numpy==1.26.4 # via # baybe (pyproject.toml) @@ -455,12 +460,13 @@ numpy==1.26.4 # e3fp # formulaic # h5py + # imageio # lifelines + # lime # matplotlib # mordredcommunity # ngboost # numba - # numba # onnx # onnxconverter-common # onnxruntime @@ -472,45 +478,47 @@ numpy==1.26.4 # pyro-ppl # rdkit # scikit-fingerprints + # scikit-image # scikit-learn # scikit-learn-extra # scipy # seaborn # shap # streamlit + # tifffile # types-seaborn # xarray # xyzpy -nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' 
and sys_platform == 'linux' # via torch -nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via # nvidia-cusolver-cu12 # torch -nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 -nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch onnx==1.16.1 # via @@ -574,6 +582,7 @@ packaging==24.1 # jupyterlab # jupyterlab-server # jupytext + # lazy-loader # matplotlib # mordredcommunity # nbconvert @@ -584,6 +593,7 @@ packaging==24.1 # plotly # pyproject-api # pytest + # scikit-image # setuptools-scm # shap # sphinx @@ -622,8 +632,10 @@ pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' pillow==10.4.0 # via # baybe (pyproject.toml) + # imageio # matplotlib # rdkit + # scikit-image # streamlit pip==24.1.2 # via pip-api @@ -794,10 +806,13 @@ s3transfer==0.10.4 # via boto3 scikit-fingerprints==1.9.0 # via baybe (pyproject.toml) +scikit-image==0.25.0 + # via lime scikit-learn==1.5.1 # via # baybe (pyproject.toml) # gpytorch + # lime # 
ngboost # scikit-fingerprints # scikit-learn-extra @@ -816,14 +831,16 @@ scipy==1.14.0 # formulaic # gpytorch # lifelines + # lime # linear-operator # ngboost # scikit-fingerprints + # scikit-image # scikit-learn # scikit-learn-extra + # shap sdaxen-python-utilities==0.1.5 # via e3fp - # shap seaborn==0.13.2 # via baybe (pyproject.toml) send2trash==1.8.3 @@ -847,10 +864,10 @@ six==1.16.0 # rfc3339-validator skl2onnx==1.17.0 # via baybe (pyproject.toml) -smart-open==7.0.5 - # via e3fp slicer==0.0.8 # via shap +smart-open==7.0.5 + # via e3fp smmap==5.0.1 # via gitdb sniffio==1.3.1 @@ -873,7 +890,6 @@ sphinx==8.1.3 # sphinx-autodoc-typehints # sphinx-basic-ng # sphinx-copybutton - # sphinx-paramlinks # sphinxcontrib-bibtex sphinx-autodoc-typehints==2.5.0 # via baybe (pyproject.toml) @@ -881,8 +897,6 @@ sphinx-basic-ng==1.0.0b2 # via furo sphinx-copybutton==0.5.2 # via baybe (pyproject.toml) -sphinx-paramlinks==0.6.0 - # via baybe (pyproject.toml) sphinxcontrib-applehelp==1.0.8 # via sphinx sphinxcontrib-bibtex==2.6.2 @@ -905,7 +919,7 @@ sympy==1.13.1 # via # onnxruntime # torch -tbb==2021.13.0 ; platform_system == 'Windows' +tbb==2021.13.0 ; sys_platform == 'win32' # via mkl tenacity==8.5.0 # via @@ -918,6 +932,8 @@ terminado==0.18.1 # jupyter-server-terminals threadpoolctl==3.5.0 # via scikit-learn +tifffile==2024.12.12 + # via scikit-image tinycss2==1.3.0 # via nbconvert tokenize-rt==6.1.0 @@ -966,6 +982,7 @@ tox-uv==1.9.1 tqdm==4.66.4 # via # huggingface-hub + # lime # ngboost # pyro-ppl # scikit-fingerprints @@ -987,7 +1004,7 @@ traitlets==5.14.3 # nbclient # nbconvert # nbformat -triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and platform_system == 'Linux' +triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux' # via torch typeguard==2.13.3 # via @@ -1030,7 +1047,7 @@ virtualenv==20.26.3 # via # pre-commit # tox -watchdog==4.0.1 ; platform_system != 'Darwin' +watchdog==4.0.1 ; 
sys_platform != 'darwin' # via streamlit wcwidth==0.2.13 # via prompt-toolkit From 3296933ece31dfb32e9ec7f2b3f58e9516af739d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 14:47:58 +0100 Subject: [PATCH 56/92] Update README.md --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ec51e55f9..37fa2ddd0 100644 --- a/README.md +++ b/README.md @@ -30,10 +30,12 @@ The **Bay**esian **B**ack **E**nd (**BayBE**) is a general-purpose toolbox for Bayesian Design of Experiments, focusing on additions that enable real-world experimental campaigns. +## ๐Ÿ”‹ Batteries Included Besides functionality to perform a typical recommend-measure loop, BayBE's highlights are: - โœจ Custom parameter encodings: Improve your campaign with domain knowledge - ๐Ÿงช Built-in chemical encodings: Improve your campaign with chemical knowledge - ๐ŸŽฏ Single and multiple targets with min, max and match objectives +- ๐Ÿ” Built-in analysis tools: Gain insights into feature importance and model behavior - ๐ŸŽญ Hybrid (mixed continuous and discrete) spaces - ๐Ÿš€ Transfer learning: Mix data from multiple campaigns and accelerate optimization - ๐ŸŽฐ Bandit models: Efficiently find the best among many options in noisy environments (e.g. A/B Testing) @@ -297,7 +299,7 @@ The available groups are: - `mypy`: Required for static type checking. - `onnx`: Required for using custom surrogate models in [ONNX format](https://onnx.ai). - `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/). -- `insights`: Required for built-in model and campaign analysis, e.g. [SHAP](https://shap.readthedocs.io/). +- `insights`: Required for built-in model and campaign analysis (e.g. using [SHAP](https://shap.readthedocs.io/)). - `simulation`: Enabling the [simulation](https://emdgroup.github.io/baybe/stable/_autosummary/baybe.simulation.html) module. - `test`: Required for running the tests. 
- `benchmarking`: Required for running the benchmarking module. From 9650397fb2811d8f90c4518ad4f66379831ff759 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 15:07:41 +0100 Subject: [PATCH 57/92] Refactor insights/__init__.py --- baybe/insights/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/baybe/insights/__init__.py b/baybe/insights/__init__.py index 767d0288a..857694610 100644 --- a/baybe/insights/__init__.py +++ b/baybe/insights/__init__.py @@ -1,10 +1,9 @@ """Baybe insights (optional).""" -from baybe._optional.info import INSIGHTS_INSTALLED - -if INSIGHTS_INSTALLED: - from baybe.insights.shap import SHAPInsight +from baybe.insights.shap import EXPLAINERS, SHAP_PLOTS, SHAPInsight __all__ = [ + "EXPLAINERS", + "SHAP_PLOTS", "SHAPInsight", ] From 1d380759b753c6fc5bc21f325ec59c1bf07b83df Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 14:47:58 +0100 Subject: [PATCH 58/92] Update README.md --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ec51e55f9..37fa2ddd0 100644 --- a/README.md +++ b/README.md @@ -30,10 +30,12 @@ The **Bay**esian **B**ack **E**nd (**BayBE**) is a general-purpose toolbox for Bayesian Design of Experiments, focusing on additions that enable real-world experimental campaigns. 
+## ๐Ÿ”‹ Batteries Included Besides functionality to perform a typical recommend-measure loop, BayBE's highlights are: - โœจ Custom parameter encodings: Improve your campaign with domain knowledge - ๐Ÿงช Built-in chemical encodings: Improve your campaign with chemical knowledge - ๐ŸŽฏ Single and multiple targets with min, max and match objectives +- ๐Ÿ” Built-in analysis tools: Gain insights into feature importance and model behavior - ๐ŸŽญ Hybrid (mixed continuous and discrete) spaces - ๐Ÿš€ Transfer learning: Mix data from multiple campaigns and accelerate optimization - ๐ŸŽฐ Bandit models: Efficiently find the best among many options in noisy environments (e.g. A/B Testing) @@ -297,7 +299,7 @@ The available groups are: - `mypy`: Required for static type checking. - `onnx`: Required for using custom surrogate models in [ONNX format](https://onnx.ai). - `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/). -- `insights`: Required for built-in model and campaign analysis, e.g. [SHAP](https://shap.readthedocs.io/). +- `insights`: Required for built-in model and campaign analysis (e.g. using [SHAP](https://shap.readthedocs.io/)). - `simulation`: Enabling the [simulation](https://emdgroup.github.io/baybe/stable/_autosummary/baybe.simulation.html) module. - `test`: Required for running the tests. - `benchmarking`: Required for running the benchmarking module. 
From bec5fc58de9702e71b9f9914b8af783f6c44f2f0 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 15:07:41 +0100 Subject: [PATCH 59/92] Refactor insights/__init__.py --- baybe/insights/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/baybe/insights/__init__.py b/baybe/insights/__init__.py index 767d0288a..857694610 100644 --- a/baybe/insights/__init__.py +++ b/baybe/insights/__init__.py @@ -1,10 +1,9 @@ """Baybe insights (optional).""" -from baybe._optional.info import INSIGHTS_INSTALLED - -if INSIGHTS_INSTALLED: - from baybe.insights.shap import SHAPInsight +from baybe.insights.shap import EXPLAINERS, SHAP_PLOTS, SHAPInsight __all__ = [ + "EXPLAINERS", + "SHAP_PLOTS", "SHAPInsight", ] From 2f5c96eaafd0e6283800020ed4ad094cf3a68d00 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 19:20:52 +0100 Subject: [PATCH 60/92] Add missing sphinx-paramlinks doc dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 55f9b9e5a..8597ebc79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ docs = [ "sphinx>=8.0.2", "sphinx-autodoc-typehints>=2.4.4", "sphinx-copybutton==0.5.2", + "sphinx-paramlinks==0.6.0", "sphinxcontrib-bibtex>=2.6.2", ] From b65731b79370f511d8f90fac01ba1ec59b041bb5 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 14:39:11 +0100 Subject: [PATCH 61/92] Update lockfile --- .lockfiles/py310-dev.lock | 75 +++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 27 deletions(-) diff --git a/.lockfiles/py310-dev.lock b/.lockfiles/py310-dev.lock index 5e4f963a5..7d5e0c1e8 100644 --- a/.lockfiles/py310-dev.lock +++ b/.lockfiles/py310-dev.lock @@ -8,7 +8,7 @@ anyio==4.4.0 # via # httpx # jupyter-server -appnope==0.1.4 ; platform_system == 'Darwin' +appnope==0.1.4 ; sys_platform == 'darwin' # via ipykernel argon2-cffi==23.1.0 # via jupyter-server @@ -232,6 +232,8 @@ idna==3.7 # httpx # jsonschema 
# requests +imageio==2.36.1 + # via scikit-image imagesize==1.4.1 # via sphinx importlib-metadata==7.1.0 @@ -240,7 +242,7 @@ importlib-metadata==7.1.0 # opentelemetry-api iniconfig==2.0.0 # via pytest -intel-openmp==2021.4.0 ; platform_system == 'Windows' +intel-openmp==2021.4.0 ; sys_platform == 'win32' # via mkl interface-meta==1.3.0 # via formulaic @@ -346,18 +348,20 @@ kiwisolver==1.4.5 # via matplotlib latexcodec==3.0.0 # via pybtex +lazy-loader==0.4 + # via scikit-image license-expression==30.3.0 # via cyclonedx-python-lib lifelines==0.29.0 # via ngboost +lime==0.2.0.1 + # via shap linear-operator==0.5.2 # via # botorch # gpytorch llvmlite==0.43.0 # via numba -llvmlite==0.43.0 - # via numba locket==1.0.0 # via partd markdown-it-py==3.0.0 @@ -374,6 +378,7 @@ matplotlib==3.9.1 # via # baybe (pyproject.toml) # lifelines + # lime # seaborn # types-seaborn matplotlib-inline==0.1.7 @@ -390,7 +395,7 @@ mdurl==0.1.2 # via markdown-it-py mistune==3.0.2 # via nbconvert -mkl==2021.4.0 ; platform_system == 'Windows' +mkl==2021.4.0 ; sys_platform == 'win32' # via torch mmh3==5.0.1 # via e3fp @@ -428,6 +433,7 @@ nest-asyncio==1.6.0 networkx==3.3 # via # mordredcommunity + # scikit-image # torch ngboost==0.5.1 # via baybe (pyproject.toml) @@ -440,9 +446,9 @@ notebook-shim==0.2.4 # jupyterlab # notebook numba==0.60.0 - # via shap -numba==0.60.0 - # via scikit-fingerprints + # via + # scikit-fingerprints + # shap numpy==1.26.4 # via # baybe (pyproject.toml) @@ -455,12 +461,13 @@ numpy==1.26.4 # e3fp # formulaic # h5py + # imageio # lifelines + # lime # matplotlib # mordredcommunity # ngboost # numba - # numba # onnx # onnxconverter-common # onnxruntime @@ -472,45 +479,47 @@ numpy==1.26.4 # pyro-ppl # rdkit # scikit-fingerprints + # scikit-image # scikit-learn # scikit-learn-extra # scipy # seaborn # shap # streamlit + # tifffile # types-seaborn # xarray # xyzpy -nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and platform_system == 'Linux' 
+nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via # nvidia-cusolver-cu12 # torch -nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 
'x86_64' and platform_system == 'Linux' +nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 -nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux' +nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch onnx==1.16.1 # via @@ -574,6 +583,7 @@ packaging==24.1 # jupyterlab # jupyterlab-server # jupytext + # lazy-loader # matplotlib # mordredcommunity # nbconvert @@ -584,6 +594,7 @@ packaging==24.1 # plotly # pyproject-api # pytest + # scikit-image # setuptools-scm # shap # sphinx @@ -622,8 +633,10 @@ pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' pillow==10.4.0 # via # baybe (pyproject.toml) + # imageio # matplotlib # rdkit + # scikit-image # streamlit pip==24.1.2 # via pip-api @@ -794,10 +807,13 @@ s3transfer==0.10.4 # via boto3 scikit-fingerprints==1.9.0 # via baybe (pyproject.toml) +scikit-image==0.25.0 + # via lime scikit-learn==1.5.1 # via # baybe (pyproject.toml) # gpytorch + # lime # ngboost # scikit-fingerprints # scikit-learn-extra @@ -816,14 +832,16 @@ scipy==1.14.0 # formulaic # gpytorch # lifelines + # lime # linear-operator # ngboost # scikit-fingerprints + # scikit-image # scikit-learn # scikit-learn-extra + # shap sdaxen-python-utilities==0.1.5 # via e3fp - # shap seaborn==0.13.2 # via baybe (pyproject.toml) send2trash==1.8.3 @@ -847,10 +865,10 @@ six==1.16.0 # rfc3339-validator skl2onnx==1.17.0 # via baybe (pyproject.toml) -smart-open==7.0.5 - # via e3fp slicer==0.0.8 # via shap +smart-open==7.0.5 + # via e3fp smmap==5.0.1 # via gitdb sniffio==1.3.1 @@ -905,7 +923,7 @@ sympy==1.13.1 # via # onnxruntime # torch -tbb==2021.13.0 ; platform_system == 'Windows' +tbb==2021.13.0 ; sys_platform == 'win32' # via mkl tenacity==8.5.0 # via @@ -918,6 +936,8 @@ terminado==0.18.1 # jupyter-server-terminals threadpoolctl==3.5.0 # via scikit-learn +tifffile==2024.12.12 + # 
via scikit-image tinycss2==1.3.0 # via nbconvert tokenize-rt==6.1.0 @@ -966,6 +986,7 @@ tox-uv==1.9.1 tqdm==4.66.4 # via # huggingface-hub + # lime # ngboost # pyro-ppl # scikit-fingerprints @@ -987,7 +1008,7 @@ traitlets==5.14.3 # nbclient # nbconvert # nbformat -triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and platform_system == 'Linux' +triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux' # via torch typeguard==2.13.3 # via @@ -1030,7 +1051,7 @@ virtualenv==20.26.3 # via # pre-commit # tox -watchdog==4.0.1 ; platform_system != 'Darwin' +watchdog==4.0.1 ; sys_platform != 'darwin' # via streamlit wcwidth==0.2.13 # via prompt-toolkit From 118de872f456d693dcc9f60007d29ffa2dee9111 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 9 Jan 2025 20:44:59 +0100 Subject: [PATCH 62/92] Fix BayBE spelling --- baybe/insights/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/insights/__init__.py b/baybe/insights/__init__.py index 857694610..c0bcf0fd3 100644 --- a/baybe/insights/__init__.py +++ b/baybe/insights/__init__.py @@ -1,4 +1,4 @@ -"""Baybe insights (optional).""" +"""BayBE insights (optional).""" from baybe.insights.shap import EXPLAINERS, SHAP_PLOTS, SHAPInsight From a1b058cc76538c804d1990abab5eb5091177362f Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 10 Jan 2025 13:01:49 +0100 Subject: [PATCH 63/92] fixup! Rework docstrings --- baybe/insights/shap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 928539dd3..15a5e10d4 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -229,7 +229,7 @@ def explain(self, df: pd.DataFrame | None = None, /) -> shap.Explanation: By default, the background data set of the explainer is used. Returns: - The computed Shapley explanation. + The computed Shapley explanation. 
Raises: ValueError: If the columns of the given dataframe cannot be aligned with the From 11e01f4a2adfe056e991bb31cfa8e0f416b4497e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 10 Jan 2025 13:05:16 +0100 Subject: [PATCH 64/92] Fix mypy issues --- baybe/insights/shap.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 15a5e10d4..e406ac20a 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -7,7 +7,6 @@ import matplotlib.pyplot as plt import numpy as np -import numpy.typing as npt import pandas as pd from attrs import define, field from attrs.validators import instance_of @@ -17,7 +16,7 @@ from baybe.objectives.base import Objective from baybe.recommenders.pure.bayesian.base import BayesianRecommender from baybe.searchspace import SearchSpace -from baybe.surrogates.base import Surrogate +from baybe.surrogates.base import Surrogate, SurrogateProtocol from baybe.utils.dataframe import to_tensor _DEFAULT_EXPLAINER_CLS = "KernelExplainer" @@ -79,7 +78,7 @@ def make_explainer_for_surrogate( if use_comp_rep: - def model(x: npt.ArrayLike) -> np.ndarray: + def model(x: np.ndarray) -> np.ndarray: tensor = to_tensor(x) with torch.no_grad(): output = surrogate._posterior_comp(tensor).mean @@ -87,7 +86,7 @@ def model(x: npt.ArrayLike) -> np.ndarray: else: - def model(x: npt.ArrayLike) -> np.ndarray: + def model(x: np.ndarray) -> np.ndarray: df = pd.DataFrame(x, columns=data.columns) with torch.no_grad(): output = surrogate.posterior(df).mean @@ -141,7 +140,7 @@ def uses_shap_explainer(self) -> bool: @classmethod def from_surrogate( cls, - surrogate: Surrogate, + surrogate: SurrogateProtocol, data: pd.DataFrame, explainer_cls: type[shap.Explainer] | str = _DEFAULT_EXPLAINER_CLS, use_comp_rep: bool = False, @@ -150,6 +149,12 @@ def from_surrogate( For details, see :func:`make_explainer_for_surrogate`. 
""" + if not isinstance(surrogate, Surrogate): + raise ValueError( + f"'{cls.__name__}.{cls.from_surrogate.__name__}' only accepts " + f"surrogate models of type '{Surrogate.__name__}' or its subclasses." + ) + explainer = make_explainer_for_surrogate( surrogate, data, explainer_cls, use_comp_rep ) From a9a0dd1a09e4112aa2753899b2e1d877f709158e Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Sun, 12 Jan 2025 21:40:35 +0100 Subject: [PATCH 65/92] Fixed permutation of explanation object. Reintroduced ValueError when initiliazing SHAPInsight from campaign without measurements. --- baybe/insights/shap.py | 14 +++++++++++--- tests/insights/test_shap.py | 4 ++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index e406ac20a..72c71a621 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -178,7 +178,12 @@ def from_campaign( Returns: The SHAP insight object. + + Raises: + ValueError: If the campaign does not contain any measurements. """ + if campaign.measurements.empty: + raise ValueError("The campaign does not contain any measurements.") data = campaign.measurements[[p.name for p in campaign.parameters]] background_data = campaign.searchspace.transform(data) if use_comp_rep else data @@ -264,15 +269,18 @@ def explain(self, df: pd.DataFrame | None = None, /) -> shap.Explanation: explanations = shap.Explanation( values=attributions, base_values=self.explainer.model(self.background_data).mean(), - data=df_aligned, + data=df_aligned.values, feature_names=df_aligned.columns.values, ) else: explanations = self.explainer(df_aligned) - # Permute explanation object according to input column order + # Permute explanation object data according to input column order. + # Do not do this for the base_values as it can be a scalar. 
idx = self.background_data.columns.get_indexer(df.columns) - explanations = explanations[:, idx] + for attr in ["values", "data"]: + setattr(explanations, attr, getattr(explanations, attr)[:, idx]) + explanations.feature_names = [explanations.feature_names[i] for i in idx] # Reduce dimensionality of explanations to 2D in case # a 3D explanation is returned. This is the case for diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 86ceb3cdf..7aa4f39b0 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -106,8 +106,8 @@ def test_invalid_explained_data(ongoing_campaign, explainer_cls, use_comp_rep): df = pd.DataFrame({"Num_disc_1": [0, 2]}) with pytest.raises( ValueError, - match="The provided data does not have the same amount of parameters as the " - "shap explainer background.", + match="The provided dataframe must have the same column names as used by " + "the explainer object.", ): shap_insight.explain(df) From c6421c5d6d7daee95537225ca725bd688c655a6b Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 13 Jan 2025 15:59:39 +0100 Subject: [PATCH 66/92] Simplify set union --- baybe/insights/shap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 72c71a621..4787cc2f5 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -28,7 +28,7 @@ "PermutationExplainer", } NON_SHAP_EXPLAINERS = {"LimeTabular", "Maple"} -EXPLAINERS = {*SHAP_EXPLAINERS, *NON_SHAP_EXPLAINERS} +EXPLAINERS = SHAP_EXPLAINERS | NON_SHAP_EXPLAINERS SHAP_PLOTS = {"bar", "beeswarm", "heatmap", "scatter"} From a447c42357fe2d473c8c3c5405bad2711d42bbe5 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 13 Jan 2025 17:17:43 +0100 Subject: [PATCH 67/92] Rename df method argument to data --- baybe/insights/shap.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py 
index 4787cc2f5..79dd46e33 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -231,11 +231,11 @@ def from_recommender( use_comp_rep=use_comp_rep, ) - def explain(self, df: pd.DataFrame | None = None, /) -> shap.Explanation: + def explain(self, data: pd.DataFrame | None = None, /) -> shap.Explanation: """Compute a Shapley explanation for a given data set. Args: - df: The dataframe for which the Shapley values are to be computed. + data: The dataframe for which the Shapley values are to be computed. By default, the background data set of the explainer is used. Returns: @@ -245,16 +245,16 @@ def explain(self, df: pd.DataFrame | None = None, /) -> shap.Explanation: ValueError: If the columns of the given dataframe cannot be aligned with the columns of the explainer background dataframe. """ - if df is None: - df = self.background_data - elif set(self.background_data.columns) != set(df.columns): + if data is None: + data = self.background_data + elif set(self.background_data.columns) != set(data.columns): raise ValueError( "The provided dataframe must have the same column names as used by " "the explainer object." ) # Align columns with background data - df_aligned = df[self.background_data.columns] + df_aligned = data[self.background_data.columns] if not self.uses_shap_explainer: # Return attributions for non-SHAP explainers @@ -277,7 +277,7 @@ def explain(self, df: pd.DataFrame | None = None, /) -> shap.Explanation: # Permute explanation object data according to input column order. # Do not do this for the base_values as it can be a scalar. 
- idx = self.background_data.columns.get_indexer(df.columns) + idx = self.background_data.columns.get_indexer(data.columns) for attr in ["values", "data"]: setattr(explanations, attr, getattr(explanations, attr)[:, idx]) explanations.feature_names = [explanations.feature_names[i] for i in idx] @@ -297,7 +297,7 @@ def explain(self, df: pd.DataFrame | None = None, /) -> shap.Explanation: def plot( self, plot_type: Literal["bar", "beeswarm", "force", "heatmap", "scatter"], - df: pd.DataFrame | None = None, + data: pd.DataFrame | None = None, /, *, show: bool = True, @@ -307,7 +307,7 @@ def plot( Args: plot_type: The type of plot to be created. - df: See :meth:`explain`. + data: See :meth:`explain`. show: Boolean flag determining if the plot is to be rendered. **kwargs: Additional keyword arguments passed to the plot function. @@ -317,11 +317,11 @@ def plot( Raises: ValueError: If the provided plot type is not supported. """ - if df is None: - df = self.background_data + if data is None: + data = self.background_data if plot_type == "scatter": - return self._plot_shap_scatter(df, show=show, **kwargs) + return self._plot_shap_scatter(data, show=show, **kwargs) if plot_type not in SHAP_PLOTS: raise ValueError( @@ -330,29 +330,31 @@ def plot( ) plot_func = getattr(shap.plots, plot_type) - return plot_func(self.explain(df), show=show, **kwargs) + return plot_func(self.explain(data), show=show, **kwargs) def _plot_shap_scatter( - self, df: pd.DataFrame | None = None, /, *, show: bool = True, **kwargs: dict + self, data: pd.DataFrame | None = None, /, *, show: bool = True, **kwargs: dict ) -> plt.Axes: """Plot the Shapley values as scatter plot, ignoring non-numeric features. For details, see :meth:`explain`. 
""" - if df is None: - df = self.background_data + if data is None: + data = self.background_data - df_numeric = df.select_dtypes("number") - numeric_idx = df.columns.get_indexer(df_numeric.columns) + df_numeric = data.select_dtypes("number") + numeric_idx = data.columns.get_indexer(df_numeric.columns) if df_numeric.empty: raise ValueError( "No SHAP scatter plot can be created since all features contain " "non-numeric values." ) - if non_numeric_cols := set(df.columns) - set(df_numeric.columns): + if non_numeric_cols := set(data.columns) - set(df_numeric.columns): warnings.warn( f"The following features are excluded from the SHAP scatter plot " f"because they contain non-numeric values: {non_numeric_cols}", UserWarning, ) - return shap.plots.scatter(self.explain(df)[:, numeric_idx], show=show, **kwargs) + return shap.plots.scatter( + self.explain(data)[:, numeric_idx], show=show, **kwargs + ) From b44f47f23375594d621668412b391e04f58b45c2 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Mon, 13 Jan 2025 20:17:57 +0100 Subject: [PATCH 68/92] Removed double docstring --- tests/insights/test_shap.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 7aa4f39b0..a8e42f9ca 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -86,8 +86,7 @@ def test_shap_explainers(ongoing_campaign, explainer_cls, use_comp_rep): @mark.parametrize("explainer_cls", NON_SHAP_EXPLAINERS) def test_non_shap_explainers(ongoing_campaign, explainer_cls): - """Test the explain functionalities with the non-SHAP explainer MAPLE.""" - """Test the non-SHAP explainer in computational representation.""" + """Test the non-SHAP explainers in computational representation.""" _test_shap_insight( ongoing_campaign, explainer_cls, use_comp_rep=True, is_shap=False ) From 03c95f7f532fe7eaac0e45d934aa46cd6ce437dd Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Mon, 13 Jan 2025 20:20:32 
+0100 Subject: [PATCH 69/92] Test cleanup --- tests/insights/test_shap.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index a8e42f9ca..f03e8e0b4 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -42,7 +42,6 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): """Helper function for general SHAP explainer tests.""" - # run_iterations(campaign, n_iterations=2, batch_size=5) try: # Sanity check explainer shap_insight = SHAPInsight.from_campaign( From e075b9eaf89eeeae81611a6170eb20e4100c2703 Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Mon, 13 Jan 2025 20:45:40 +0100 Subject: [PATCH 70/92] Improved comments and docstring --- baybe/insights/shap.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 79dd46e33..c30385862 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -145,7 +145,7 @@ def from_surrogate( explainer_cls: type[shap.Explainer] | str = _DEFAULT_EXPLAINER_CLS, use_comp_rep: bool = False, ): - """Create a SHAP insight from a campaign. + """Create a SHAP insight from a surrogate. For details, see :func:`make_explainer_for_surrogate`. 
""" @@ -320,6 +320,7 @@ def plot( if data is None: data = self.background_data + # Use custom scatter plot function to ignore non-numeric features if plot_type == "scatter": return self._plot_shap_scatter(data, show=show, **kwargs) From bf9cf2bd61ab31c2c8e026411f427090a879891b Mon Sep 17 00:00:00 2001 From: "Wieczorek, Alexander" Date: Mon, 13 Jan 2025 20:48:30 +0100 Subject: [PATCH 71/92] Filter data to measurement parameters only when initializing from recommender --- baybe/insights/shap.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index c30385862..d087d53d6 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -224,9 +224,11 @@ def from_recommender( searchspace, objective, measurements ) + data = measurements[[p.name for p in searchspace.parameters]] + return cls.from_surrogate( surrogate_model, - searchspace.transform(measurements) if use_comp_rep else measurements, + searchspace.transform(data) if use_comp_rep else data, explainer_cls=explainer_cls, use_comp_rep=use_comp_rep, ) From f89f9efd327f498edda958cf23637a50f8a83f54 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 09:25:19 +0100 Subject: [PATCH 72/92] Validate that explainer is of accepted type --- baybe/insights/shap.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index d087d53d6..8c40f5e88 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -132,6 +132,15 @@ class SHAPInsight: background_data: pd.DataFrame = field(validator=instance_of(pd.DataFrame)) """The background data set used by the explainer.""" + @explainer.validator + def _validate_explainer(self, _, explainer: shap.Explainer) -> None: + """Validate the explainer type.""" + if (name := explainer.__class__.__name__) not in EXPLAINERS: + raise ValueError( + f"The given explainer type must be one of {EXPLAINERS}. " + f"Given: '{name}'." 
+ ) + @property def uses_shap_explainer(self) -> bool: """Indicates if a SHAP explainer is used or not (e.g. MAPLE, LIME).""" From 8282b9ae0f25432cb3582a83b79cda49eb6fee38 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 09:06:11 +0100 Subject: [PATCH 73/92] Avoid test collection import errors due to optional dependencies --- baybe/exceptions.py | 4 +++- pytest.ini | 10 +++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/baybe/exceptions.py b/baybe/exceptions.py index 661f61a97..75f6bd440 100644 --- a/baybe/exceptions.py +++ b/baybe/exceptions.py @@ -2,6 +2,8 @@ ##### Warnings ##### + + class UnusedObjectWarning(UserWarning): """ A method or function was called with undesired arguments which indicates an @@ -59,7 +61,7 @@ class NumericalUnderflowError(Exception): """A computation would lead to numerical underflow.""" -class OptionalImportError(Exception): +class OptionalImportError(ImportError): """An attempt was made to import an optional but uninstalled dependency.""" diff --git a/pytest.ini b/pytest.ini index c993cc465..3699b8770 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,13 +3,9 @@ addopts = --doctest-modules --ignore=examples --ignore=docs - - ; TODO: The following modules are ignored due to optional dependencies, which - ; otherwise break test collection in core test environment. - ; Probably, there is a more elegant solution to it. 
- --ignore=baybe/_optional - --ignore=baybe/utils/chemistry.py - --ignore=tests/simulate_telemetry.py + + ; Avoids import errors due to optional dependencies + --doctest-ignore-import-errors testpaths = baybe tests \ No newline at end of file From 856d4fd53b764687e29c43327d4f74e137364a82 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 10:17:34 +0100 Subject: [PATCH 74/92] Drop try-except branch for unsupported explainers types Not needed because the list of supported explainers is hard-coded and all explainer types are tested in the insights test module --- baybe/insights/shap.py | 5 ----- tests/insights/test_shap.py | 11 +---------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 8c40f5e88..56f3ade3a 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -101,11 +101,6 @@ def model(x: np.ndarray) -> np.ndarray: # Explain first two data points to ensure that the explainer is working if is_shap_explainer(shap_explainer): shap_explainer(data.iloc[0:1]) - except shap.utils._exceptions.InvalidModelError: - raise TypeError( - f"The selected explainer class {explainer_cls} does not support the " - f"provided surrogate model." 
- ) except TypeError as e: if "not supported for the input types" in str(e) and not use_comp_rep: raise NotImplementedError( diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index f03e8e0b4..9ac618af8 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -59,20 +59,11 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): df = campaign.searchspace.transform(df) shap_explanation = shap_insight.explain(df) assert isinstance(shap_explanation, shap.Explanation) - except TypeError as e: + except NotImplementedError as e: if "The selected explainer class" in str(e): pytest.xfail("Unsupported model/explainer combination") else: raise e - except NotImplementedError as e: - if ( - "The selected explainer class" in str(e) - and not use_comp_rep - and not isinstance(explainer_cls, shap.explainers.KernelExplainer) - ): - pytest.xfail("Exp. rep. not supported") - else: - raise e @mark.slow From 647ea729b83c5d674d3955a797b8d6900e909c47 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 10:36:31 +0100 Subject: [PATCH 75/92] Refactor explainer incompatibility handling using type validation --- baybe/exceptions.py | 4 ++++ baybe/insights/shap.py | 30 +++++++++++++----------------- tests/insights/test_shap.py | 31 +++++++++++++++---------------- 3 files changed, 32 insertions(+), 33 deletions(-) diff --git a/baybe/exceptions.py b/baybe/exceptions.py index 75f6bd440..866db73a4 100644 --- a/baybe/exceptions.py +++ b/baybe/exceptions.py @@ -29,6 +29,10 @@ class IncompatibleAcquisitionFunctionError(IncompatibilityError): """An incompatible acquisition function was selected.""" +class IncompatibleExplainerError(IncompatibilityError): + """An explainer is incompatible with the data it is presented.""" + + class NotEnoughPointsLeftError(Exception): """ More recommendations are requested than there are viable parameter configurations diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 
56f3ade3a..15e5d59ba 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -10,9 +10,11 @@ import pandas as pd from attrs import define, field from attrs.validators import instance_of +from shap import KernelExplainer from baybe import Campaign from baybe._optional.insights import shap +from baybe.exceptions import IncompatibleExplainerError from baybe.objectives.base import Objective from baybe.recommenders.pure.bayesian.base import BayesianRecommender from baybe.searchspace import SearchSpace @@ -74,6 +76,16 @@ def make_explainer_for_surrogate( if isinstance(explainer_cls, str): explainer_cls = _get_explainer_cls(explainer_cls) + if not ( + data.select_dtypes(exclude="number").empty + or isinstance(explainer_cls, KernelExplainer) + ): + raise IncompatibleExplainerError( + f"The selected explainer class '{explainer_cls.__name__}' does not support " + f"categorical data. Switch to computational representation or use " + f"'{KernelExplainer.__name__}'." + ) + import torch if use_comp_rep: @@ -95,23 +107,7 @@ def model(x: np.ndarray) -> np.ndarray: # Handle special settings: Lime default mode is otherwise set to "classification" kwargs = {"mode": "regression"} if explainer_cls.__name__ == "LimeTabular" else {} - try: - shap_explainer = explainer_cls(model, data, **kwargs) - - # Explain first two data points to ensure that the explainer is working - if is_shap_explainer(shap_explainer): - shap_explainer(data.iloc[0:1]) - except TypeError as e: - if "not supported for the input types" in str(e) and not use_comp_rep: - raise NotImplementedError( - f"The selected explainer class {explainer_cls} does not support " - f"the experimental representation. Switch to computational " - f"representation or use a different explainer (e.g. the default " - f"shap.KernelExplainer)." 
- ) - else: - raise e - return shap_explainer + return explainer_cls(model, data, **kwargs) @define diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 9ac618af8..f174667d2 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -7,6 +7,7 @@ from pytest import mark from baybe._optional.info import INSIGHTS_INSTALLED +from baybe.exceptions import IncompatibleExplainerError if not INSIGHTS_INSTALLED: pytest.skip("Optional insights package not installed.", allow_module_level=True) @@ -43,27 +44,25 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): """Helper function for general SHAP explainer tests.""" try: - # Sanity check explainer shap_insight = SHAPInsight.from_campaign( campaign, explainer_cls=explainer_cls, use_comp_rep=use_comp_rep, ) - assert isinstance(shap_insight, insights.SHAPInsight) - assert isinstance(shap_insight.explainer, _get_explainer_cls(explainer_cls)) - assert shap_insight.uses_shap_explainer == is_shap - - # Sanity check explanation - df = campaign.measurements[[p.name for p in campaign.parameters]] - if use_comp_rep: - df = campaign.searchspace.transform(df) - shap_explanation = shap_insight.explain(df) - assert isinstance(shap_explanation, shap.Explanation) - except NotImplementedError as e: - if "The selected explainer class" in str(e): - pytest.xfail("Unsupported model/explainer combination") - else: - raise e + except IncompatibleExplainerError: + pytest.xfail("Unsupported model/explainer combination.") + + # Sanity check explainer + assert isinstance(shap_insight, insights.SHAPInsight) + assert isinstance(shap_insight.explainer, _get_explainer_cls(explainer_cls)) + assert shap_insight.uses_shap_explainer == is_shap + + # Sanity check explanation + df = campaign.measurements[[p.name for p in campaign.parameters]] + if use_comp_rep: + df = campaign.searchspace.transform(df) + shap_explanation = shap_insight.explain(df) + assert isinstance(shap_explanation, 
shap.Explanation) @mark.slow From 5926496eb456d3299d01043fe551786d01d1ea7c Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 11:11:45 +0100 Subject: [PATCH 76/92] Move content of temporary test file by overriding fixtures locally --- tests/insights/test_shap.py | 75 ++++++++++++++++++++++++++++++------- tests/insights/test_tmp.py | 50 ------------------------- 2 files changed, 61 insertions(+), 64 deletions(-) delete mode 100644 tests/insights/test_tmp.py diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index f174667d2..06b519361 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -1,10 +1,13 @@ """Tests for insights subpackage.""" +import inspect from unittest import mock +import numpy as np import pandas as pd import pytest from pytest import mark +from shap import KernelExplainer from baybe._optional.info import INSIGHTS_INSTALLED from baybe.exceptions import IncompatibleExplainerError @@ -25,20 +28,45 @@ ) from tests.conftest import run_iterations -# File-wide parameterization settings -pytestmark = [ - mark.parametrize("n_grid_points", [5], ids=["g5"]), - mark.parametrize("n_iterations", [2], ids=["i2"]), - mark.parametrize("batch_size", [2], ids=["b2"]), - mark.parametrize( - "parameter_names", - [ - ["Conti_finite1", "Conti_finite2"], - ["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"], - ], - ids=["conti_params", "hybrid_params"], - ), -] + +@pytest.fixture +def n_grid_points(): + return 5 + + +@pytest.fixture +def n_iterations(): + return 2 + + +@pytest.fixture +def batch_size(): + return 2 + + +@pytest.fixture( + params=[ + ["Conti_finite1", "Conti_finite2"], + ["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"], + ], + ids=["conti_params", "hybrid_params"], +) +def parameter_names(request): + return request.param + + +def _has_required_init_parameters(cls: type[shap.Explainer]) -> bool: + """Check if non-shap initializer has required standard parameters.""" + 
REQUIRED_PARAMETERS = ["self", "model", "data"] + init_signature = inspect.signature(cls.__init__) + parameters = list(init_signature.parameters.keys()) + return parameters[:3] == REQUIRED_PARAMETERS + + +@pytest.mark.parametrize("explainer_name", NON_SHAP_EXPLAINERS) +def test_non_shap_signature(explainer_name): + """Non-SHAP explainers must have the required signature.""" + assert _has_required_init_parameters(_get_explainer_cls(explainer_name)) def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): @@ -147,3 +175,22 @@ def test_creation_from_recommender(ongoing_campaign): ongoing_campaign.measurements, ) assert isinstance(shap_insight, insights.SHAPInsight) + + +def test_column_permutation(): + """Explaining data with permuted columns gives permuted explanations.""" + N = 10 + + # Create insights object and test data + background_data = pd.DataFrame(np.random.random((N, 3)), columns=["x", "y", "z"]) + explainer = KernelExplainer(lambda x: x, background_data) + insights = SHAPInsight(explainer, background_data) + df = pd.DataFrame(np.random.random((N, 3)), columns=["x", "y", "z"]) + + # Regular column order + ex1 = insights.explain(df) + + # Permuted column order + ex2 = insights.explain(df[["z", "x", "y"]])[:, [1, 2, 0]] + + assert np.array_equal(ex1.values, ex2.values) diff --git a/tests/insights/test_tmp.py b/tests/insights/test_tmp.py deleted file mode 100644 index ba9e09c5a..000000000 --- a/tests/insights/test_tmp.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Temporary test file.""" - -import inspect - -import numpy as np -import pandas as pd -import pytest - -from baybe._optional.info import INSIGHTS_INSTALLED - -if not INSIGHTS_INSTALLED: - pytest.skip("Optional insights package not installed.", allow_module_level=True) - -import shap -from shap.explainers import KernelExplainer - -from baybe.insights.shap import NON_SHAP_EXPLAINERS, SHAPInsight, _get_explainer_cls - - -def _has_required_init_parameters(cls: type[shap.Explainer]) -> bool: - 
"""Check if non-shap initializer has required standard parameters.""" - REQUIRED_PARAMETERS = ["self", "model", "data"] - init_signature = inspect.signature(cls.__init__) - parameters = list(init_signature.parameters.keys()) - return parameters[:3] == REQUIRED_PARAMETERS - - -@pytest.mark.parametrize("explainer_name", NON_SHAP_EXPLAINERS) -def test_non_shap_signature(explainer_name): - """Non-SHAP explainers must have the required signature.""" - assert _has_required_init_parameters(_get_explainer_cls(explainer_name)) - - -def test_column_permutation(): - """Explaining data with permuted columns gives permuted explanations.""" - N = 10 - - # Create insights object and test data - background_data = pd.DataFrame(np.random.random((N, 3)), columns=["x", "y", "z"]) - explainer = KernelExplainer(lambda x: x, background_data) - insights = SHAPInsight(explainer, background_data) - df = pd.DataFrame(np.random.random((N, 3)), columns=["x", "y", "z"]) - - # Regular column order - ex1 = insights.explain(df) - - # Permuted column order - ex2 = insights.explain(df[["z", "x", "y"]])[:, [1, 2, 0]] - - assert np.array_equal(ex1.values, ex2.values) From 432942f3816ff0dae743682383a6db6c75479274 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 11:16:04 +0100 Subject: [PATCH 77/92] Fix type check --- baybe/insights/shap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 15e5d59ba..960fcb7bd 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -78,7 +78,7 @@ def make_explainer_for_surrogate( if not ( data.select_dtypes(exclude="number").empty - or isinstance(explainer_cls, KernelExplainer) + or issubclass(explainer_cls, KernelExplainer) ): raise IncompatibleExplainerError( f"The selected explainer class '{explainer_cls.__name__}' does not support " From 06fac923140a5386f4d4fd90ebc1130dc791e66f Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 11:19:54 +0100 Subject: 
[PATCH 78/92] Replace xfail with skip --- tests/insights/test_shap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 06b519361..417871d5a 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -78,7 +78,7 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): use_comp_rep=use_comp_rep, ) except IncompatibleExplainerError: - pytest.xfail("Unsupported model/explainer combination.") + pytest.skip("Unsupported model/explainer combination.") # Sanity check explainer assert isinstance(shap_insight, insights.SHAPInsight) From 6010004df2e76b98fd737754ab6f3fc6736f7d6d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 13:38:24 +0100 Subject: [PATCH 79/92] Refine permutation workaround --- baybe/insights/shap.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 960fcb7bd..ea526b2cd 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -277,11 +277,19 @@ def explain(self, data: pd.DataFrame | None = None, /) -> shap.Explanation: else: explanations = self.explainer(df_aligned) - # Permute explanation object data according to input column order. - # Do not do this for the base_values as it can be a scalar. 
+ # Permute explanation object data according to input column order + # (`base_values` can be a scalar or vector) + # TODO: https://github.com/shap/shap/issues/3958 idx = self.background_data.columns.get_indexer(data.columns) - for attr in ["values", "data"]: - setattr(explanations, attr, getattr(explanations, attr)[:, idx]) + for attr in ["values", "data", "base_values"]: + try: + setattr(explanations, attr, getattr(explanations, attr)[:, idx]) + except IndexError as ex: + if not ( + isinstance(explanations.base_values, float) + or explanations.base_values.shape[1] == 1 + ): + raise TypeError("Unexpected explanation format.") from ex explanations.feature_names = [explanations.feature_names[i] for i in idx] # Reduce dimensionality of explanations to 2D in case From 56d2b07065759de60103553fd250fe929c3e902a Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 13:55:32 +0100 Subject: [PATCH 80/92] Update lockfile --- .lockfiles/py310-dev.lock | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.lockfiles/py310-dev.lock b/.lockfiles/py310-dev.lock index 003b461b2..7d5e0c1e8 100644 --- a/.lockfiles/py310-dev.lock +++ b/.lockfiles/py310-dev.lock @@ -137,6 +137,7 @@ docutils==0.21.2 # myst-parser # pybtex-docutils # sphinx + # sphinx-paramlinks # sphinxcontrib-bibtex e3fp==1.2.5 # via scikit-fingerprints @@ -890,6 +891,7 @@ sphinx==8.1.3 # sphinx-autodoc-typehints # sphinx-basic-ng # sphinx-copybutton + # sphinx-paramlinks # sphinxcontrib-bibtex sphinx-autodoc-typehints==2.5.0 # via baybe (pyproject.toml) @@ -897,6 +899,8 @@ sphinx-basic-ng==1.0.0b2 # via furo sphinx-copybutton==0.5.2 # via baybe (pyproject.toml) +sphinx-paramlinks==0.6.0 + # via baybe (pyproject.toml) sphinxcontrib-applehelp==1.0.8 # via sphinx sphinxcontrib-bibtex==2.6.2 From cb86db4abb3ddd06f607ce47c43f4135c0d477ae Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 13:57:06 +0100 Subject: [PATCH 81/92] Fix sphinx references --- baybe/insights/shap.py | 8 ++++---- 1 file 
changed, 4 insertions(+), 4 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index ea526b2cd..bd80882a9 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -173,8 +173,8 @@ def from_campaign( Args: campaign: A campaign holding a recommender using a surrogate model. - explainer_cls: See :func:`make_explainer_for_surrogate. - use_comp_rep: See :func:`make_explainer_for_surrogate. + explainer_cls: See :func:`make_explainer_for_surrogate`. + use_comp_rep: See :func:`make_explainer_for_surrogate`. Returns: The SHAP insight object. @@ -214,8 +214,8 @@ def from_recommender( searchspace: The searchspace for the recommender. objective: The objective for the recommender. measurements: The measurements for training the surrogate and the explainer. - explainer_cls: See :func:`make_explainer_for_surrogate. - use_comp_rep: See :func:`make_explainer_for_surrogate. + explainer_cls: See :func:`make_explainer_for_surrogate`. + use_comp_rep: See :func:`make_explainer_for_surrogate`. Returns: The SHAP insight object. From cc41592926e5df975121f38506783c718ff81470 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 14:01:35 +0100 Subject: [PATCH 82/92] Improve docstrings --- baybe/insights/shap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index bd80882a9..d6dbd0fdd 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -114,7 +114,7 @@ def model(x: np.ndarray) -> np.ndarray: class SHAPInsight: """Class for SHAP-based feature importance insights. - Also supports LIME and MAPLE explainers via the ``shap`` module. + Also supports LIME and MAPLE explainers via the ``shap`` package. """ explainer: shap.Explainer = field(validator=instance_of(shap.Explainer)) @@ -172,7 +172,7 @@ def from_campaign( Uses the measurements of the campaign as background data. Args: - campaign: A campaign holding a recommender using a surrogate model. 
+ campaign: A campaign using a surrogate-based recommender. explainer_cls: See :func:`make_explainer_for_surrogate`. use_comp_rep: See :func:`make_explainer_for_surrogate`. From ba7202dea645057442167fd56844ec7e2a62e9bd Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 17:31:57 +0100 Subject: [PATCH 83/92] Reintroduce recommender validation guard clause --- baybe/insights/shap.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index d6dbd0fdd..b312465b5 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -16,7 +16,7 @@ from baybe._optional.insights import shap from baybe.exceptions import IncompatibleExplainerError from baybe.objectives.base import Objective -from baybe.recommenders.pure.bayesian.base import BayesianRecommender +from baybe.recommenders.base import RecommenderProtocol from baybe.searchspace import SearchSpace from baybe.surrogates.base import Surrogate, SurrogateProtocol from baybe.utils.dataframe import to_tensor @@ -197,7 +197,7 @@ def from_campaign( @classmethod def from_recommender( cls, - recommender: BayesianRecommender, + recommender: RecommenderProtocol, searchspace: SearchSpace, objective: Objective, measurements: pd.DataFrame, @@ -219,7 +219,17 @@ def from_recommender( Returns: The SHAP insight object. + + Raises: + TypeError: If the recommender has no ``get_surrogate`` method. """ + if not hasattr(recommender, "get_surrogate"): + raise TypeError( + f"The provided recommender does not provide a surrogate model. " + f"'{cls.__name__}' needs a surrogate model and thus only works with " + f"model-based recommenders." 
+ ) + surrogate_model = recommender.get_surrogate( searchspace, objective, measurements ) From 1077904e58aab01b3cf93ec43e01014a5e05829e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 17:36:26 +0100 Subject: [PATCH 84/92] Make use_comp_rep flag keyword-only --- baybe/insights/shap.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index b312465b5..a8ad1dfd5 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -52,6 +52,7 @@ def make_explainer_for_surrogate( surrogate: Surrogate, data: pd.DataFrame, explainer_cls: type[shap.Explainer] | str = _DEFAULT_EXPLAINER_CLS, + *, use_comp_rep: bool = False, ) -> shap.Explainer: """Create a SHAP explainer for a given surrogate model. @@ -143,6 +144,7 @@ def from_surrogate( surrogate: SurrogateProtocol, data: pd.DataFrame, explainer_cls: type[shap.Explainer] | str = _DEFAULT_EXPLAINER_CLS, + *, use_comp_rep: bool = False, ): """Create a SHAP insight from a surrogate. @@ -156,7 +158,7 @@ def from_surrogate( ) explainer = make_explainer_for_surrogate( - surrogate, data, explainer_cls, use_comp_rep + surrogate, data, explainer_cls, use_comp_rep=use_comp_rep ) return cls(explainer, data) @@ -165,6 +167,7 @@ def from_campaign( cls, campaign: Campaign, explainer_cls: type[shap.Explainer] | str = _DEFAULT_EXPLAINER_CLS, + *, use_comp_rep: bool = False, ) -> SHAPInsight: """Create a SHAP insight from a campaign. @@ -202,6 +205,7 @@ def from_recommender( objective: Objective, measurements: pd.DataFrame, explainer_cls: type[shap.Explainer] | str = "KernelExplainer", + *, use_comp_rep: bool = False, ) -> SHAPInsight: """Create a SHAP insight from a recommender. 
From 6a5a4d923709e3f3e4a43f559e0dc11c162fb803 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 17:55:18 +0100 Subject: [PATCH 85/92] Adjust batteries included text in README --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 37fa2ddd0..a0fdfe27b 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,11 @@ The **Bay**esian **B**ack **E**nd (**BayBE**) is a general-purpose toolbox for B of Experiments, focusing on additions that enable real-world experimental campaigns. ## ๐Ÿ”‹ Batteries Included -Besides functionality to perform a typical recommend-measure loop, BayBE's highlights are: -- โœจ Custom parameter encodings: Improve your campaign with domain knowledge +Besides its core functionality to perform a typical recommend-measure loop, BayBE +offers a range of โœจ**built‑in features**โœจ crucial for real-world use cases. +The following provides a non-comprehensive overview: + +- ๐Ÿ› ๏ธ Custom parameter encodings: Improve your campaign with domain knowledge - ๐Ÿงช Built-in chemical encodings: Improve your campaign with chemical knowledge - ๐ŸŽฏ Single and multiple targets with min, max and match objectives - ๐Ÿ” Built-in analysis tools: Gain insights into feature importance and model behavior From 0f8c6f3de7f19d42c15c2a4c6ec09e466e664bad Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 14 Jan 2025 18:23:35 +0100 Subject: [PATCH 86/92] Move import statement to avoid test fail --- tests/insights/test_shap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 417871d5a..6bdc6a526 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -7,7 +7,6 @@ import pandas as pd import pytest from pytest import mark -from shap import KernelExplainer from baybe._optional.info import INSIGHTS_INSTALLED from baybe.exceptions import IncompatibleExplainerError @@ -23,6 +22,7 @@ NON_SHAP_EXPLAINERS, 
SHAP_EXPLAINERS, SHAP_PLOTS, + KernelExplainer, SHAPInsight, _get_explainer_cls, ) From 4b8842de2ad0727a9e12385d0a783ddaedc70b05 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Mon, 20 Jan 2025 12:37:41 +0100 Subject: [PATCH 87/92] Rephrase bullet point --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a0fdfe27b..7aa26c2d9 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ The following provides a non-comprehensive overview: - ๐Ÿ› ๏ธ Custom parameter encodings: Improve your campaign with domain knowledge - ๐Ÿงช Built-in chemical encodings: Improve your campaign with chemical knowledge - ๐ŸŽฏ Single and multiple targets with min, max and match objectives -- ๐Ÿ” Built-in analysis tools: Gain insights into feature importance and model behavior +- ๐Ÿ” Insights: Easily analyze feature importance and model behavior - ๐ŸŽญ Hybrid (mixed continuous and discrete) spaces - ๐Ÿš€ Transfer learning: Mix data from multiple campaigns and accelerate optimization - ๐ŸŽฐ Bandit models: Efficiently find the best among many options in noisy environments (e.g. 
A/B Testing) From b0879c27cef6460e7e0acac147822f42cc1bc696 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Mon, 20 Jan 2025 13:58:49 +0100 Subject: [PATCH 88/92] Enable force plot --- baybe/insights/shap.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index a8ad1dfd5..81b5b7a14 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -31,7 +31,7 @@ } NON_SHAP_EXPLAINERS = {"LimeTabular", "Maple"} EXPLAINERS = SHAP_EXPLAINERS | NON_SHAP_EXPLAINERS -SHAP_PLOTS = {"bar", "beeswarm", "heatmap", "scatter"} +SHAP_PLOTS = {"bar", "beeswarm", "force", "heatmap", "scatter"} def _get_explainer_cls(name: str) -> type[shap.Explainer]: @@ -325,6 +325,7 @@ def plot( /, *, show: bool = True, + explanation_idx: int | None = None, **kwargs: dict, ) -> plt.Axes: """Plot the Shapley values using the provided plot type. @@ -333,6 +334,9 @@ def plot( plot_type: The type of plot to be created. data: See :meth:`explain`. show: Boolean flag determining if the plot is to be rendered. + explanation_idx: Positional index of the data point that should be + explained. Only relevant for plot types that can only handle a single + data point. **kwargs: Additional keyword arguments passed to the plot function. Returns: @@ -355,7 +359,21 @@ def plot( ) plot_func = getattr(shap.plots, plot_type) - return plot_func(self.explain(data), show=show, **kwargs) + # Handle plot types that only explain a single data point + if plot_type in {"force"}: + if explanation_idx is None: + warnings.warn( + f"When using plot type '{plot_type}', a 'explanation_idx' must be " + f"chosen to identify a single data point that should be explained. " + f"Choosing the first measurement at position 0." 
+ ) + explanation_idx = 0 + toplot = self.explain(data.iloc[[explanation_idx]]) + kwargs["matplotlib"] = True + else: + toplot = self.explain(data) + + return plot_func(toplot, show=show, **kwargs) def _plot_shap_scatter( self, data: pd.DataFrame | None = None, /, *, show: bool = True, **kwargs: dict From 2a8623bdbb8c6c92582dd40825e75141329b5862 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Mon, 20 Jan 2025 14:06:56 +0100 Subject: [PATCH 89/92] Fix kwargs type hints --- baybe/insights/shap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 81b5b7a14..63b30987a 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -3,7 +3,7 @@ from __future__ import annotations import warnings -from typing import Literal +from typing import Any, Literal import matplotlib.pyplot as plt import numpy as np @@ -326,7 +326,7 @@ def plot( *, show: bool = True, explanation_idx: int | None = None, - **kwargs: dict, + **kwargs: Any, ) -> plt.Axes: """Plot the Shapley values using the provided plot type. @@ -376,7 +376,7 @@ def plot( return plot_func(toplot, show=show, **kwargs) def _plot_shap_scatter( - self, data: pd.DataFrame | None = None, /, *, show: bool = True, **kwargs: dict + self, data: pd.DataFrame | None = None, /, *, show: bool = True, **kwargs: Any ) -> plt.Axes: """Plot the Shapley values as scatter plot, ignoring non-numeric features. 
From 8a27738230d69cd7ac97e79726bc0d9ee2eaff72 Mon Sep 17 00:00:00 2001 From: Martin Fitzner <17951239+Scienfitz@users.noreply.github.com> Date: Mon, 20 Jan 2025 15:04:19 +0100 Subject: [PATCH 90/92] Apply suggestions from code review Include improvements Co-authored-by: AdrianSosic --- baybe/insights/shap.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py index 63b30987a..469340344 100644 --- a/baybe/insights/shap.py +++ b/baybe/insights/shap.py @@ -325,7 +325,7 @@ def plot( /, *, show: bool = True, - explanation_idx: int | None = None, + explanation_index: int | None = None, **kwargs: Any, ) -> plt.Axes: """Plot the Shapley values using the provided plot type. @@ -334,7 +334,7 @@ def plot( plot_type: The type of plot to be created. data: See :meth:`explain`. show: Boolean flag determining if the plot is to be rendered. - explanation_idx: Positional index of the data point that should be + explanation_index: Positional index of the data point that should be explained. Only relevant for plot types that can only handle a single data point. **kwargs: Additional keyword arguments passed to the plot function. @@ -360,15 +360,15 @@ def plot( plot_func = getattr(shap.plots, plot_type) # Handle plot types that only explain a single data point - if plot_type in {"force"}: - if explanation_idx is None: + if plot_type == "force": + if explanation_index is None: warnings.warn( - f"When using plot type '{plot_type}', a 'explanation_idx' must be " - f"chosen to identify a single data point that should be explained. " - f"Choosing the first measurement at position 0." + f"When using plot type '{plot_type}', an 'explanation_index' must " + f"be chosen to identify a single data point that should be " + f"explained. Choosing the first entry at position 0." 
) - explanation_idx = 0 - toplot = self.explain(data.iloc[[explanation_idx]]) + explanation_index = 0 + toplot = self.explain(data.iloc[[explanation_index]]) kwargs["matplotlib"] = True else: toplot = self.explain(data) From 0baf9b79daf91758587b171347b035d08d3a56ed Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Mon, 20 Jan 2025 18:18:50 +0100 Subject: [PATCH 91/92] Expand tests --- tests/insights/test_shap.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 6bdc6a526..0fc3eabef 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -47,9 +47,18 @@ def batch_size(): @pytest.fixture( params=[ ["Conti_finite1", "Conti_finite2"], - ["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"], + ["Num_disc_1", "Fraction_1"], + ["Num_disc_1", "Conti_finite1"], + ["Num_disc_1", "Categorical_1"], + ["Conti_finite1", "Categorical_1"], + ], + ids=[ + "params_conti", + "params_disc_num", + "params_hybrid_num", + "params_disc_cat", + "params_hybrid_cat", ], - ids=["conti_params", "hybrid_params"], ) def parameter_names(request): return request.param From 8797dab25d6b7f06803e0851c5acc97d2eba2a3c Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Mon, 20 Jan 2025 18:35:25 +0100 Subject: [PATCH 92/92] Use context for expected failures --- tests/insights/test_shap.py | 38 +++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 0fc3eabef..4d231999b 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -1,6 +1,7 @@ """Tests for insights subpackage.""" import inspect +from contextlib import nullcontext from unittest import mock import numpy as np @@ -80,26 +81,35 @@ def test_non_shap_signature(explainer_name): def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): """Helper function for general SHAP explainer 
tests.""" - try: + context = nullcontext() + if ( + (not use_comp_rep) + and (explainer_cls != "KernelExplainer") + and any(not p.is_numerical for p in campaign.parameters) + ): + # We expect a validation error in case an explanation with an unsupported + # explainer type is attempted on a search space representation with + # non-numerical entries + context = pytest.raises(IncompatibleExplainerError) + + with context: shap_insight = SHAPInsight.from_campaign( campaign, explainer_cls=explainer_cls, use_comp_rep=use_comp_rep, ) - except IncompatibleExplainerError: - pytest.skip("Unsupported model/explainer combination.") - - # Sanity check explainer - assert isinstance(shap_insight, insights.SHAPInsight) - assert isinstance(shap_insight.explainer, _get_explainer_cls(explainer_cls)) - assert shap_insight.uses_shap_explainer == is_shap - # Sanity check explanation - df = campaign.measurements[[p.name for p in campaign.parameters]] - if use_comp_rep: - df = campaign.searchspace.transform(df) - shap_explanation = shap_insight.explain(df) - assert isinstance(shap_explanation, shap.Explanation) + # Sanity check explainer + assert isinstance(shap_insight, insights.SHAPInsight) + assert isinstance(shap_insight.explainer, _get_explainer_cls(explainer_cls)) + assert shap_insight.uses_shap_explainer == is_shap + + # Sanity check explanation + df = campaign.measurements[[p.name for p in campaign.parameters]] + if use_comp_rep: + df = campaign.searchspace.transform(df) + shap_explanation = shap_insight.explain(df) + assert isinstance(shap_explanation, shap.Explanation) @mark.slow