Insights module and SHAPInsight #391

Merged
merged 102 commits into from
Jan 20, 2025
Changes from 21 commits
Commits
0c8e945
Optional import of shap package.
Alex6022 Sep 29, 2024
cbe2e82
1st implementation of SHAP utilities in experimental space and with p…
Alex6022 Sep 29, 2024
2597fd4
Implementation option to perform SHAP either in computational or expe…
Alex6022 Oct 1, 2024
bc1203e
SHAP package implementation in diagnostics utility, complete tests an…
Alex6022 Oct 3, 2024
b348b46
Tests for explainer utilities and generalization for all explainers i…
Alex6022 Oct 3, 2024
ae20322
Implemented plotting with non-shap attributions.
Alex6022 Oct 3, 2024
de9d1e9
Refactored diagnostics test and optimized handling of maple explainers.
Alex6022 Oct 4, 2024
e183957
Shortened plotting method names.
Alex6022 Oct 4, 2024
85fb9ba
Merge branch 'emdgroup:main' into feature/shap-utils
Alex6022 Oct 4, 2024
c389ac1
Cleanup for PR
Alex6022 Oct 4, 2024
55e723c
Renamed diagnostics package, enabled optional shap import
Alex6022 Oct 23, 2024
1922467
Merge branch 'emdgroup:main' into feature/shap-utils
Alex6022 Oct 23, 2024
ee57008
Refactoring of test_diagnostics.py
Alex6022 Oct 27, 2024
11b61d1
Merge branch 'emdgroup:main' into feature/shap-utils
Alex6022 Oct 28, 2024
08a4c1e
Merge branch 'feature/shap-utils' of https://github.com/Alex6022/bayb…
Alex6022 Oct 28, 2024
50846f7
Merge branch 'feature/shap-utils' of https://github.com/Alex6022/bayb…
Alex6022 Oct 28, 2024
103a5f7
Fixed changelog merging error
Alex6022 Oct 28, 2024
ffda991
Update pyproject.toml
Scienfitz Nov 1, 2024
eaa5c38
Rework import flag
Scienfitz Nov 1, 2024
4ca9ffd
Update mypy.ini
Scienfitz Nov 1, 2024
9fddbdd
Rework tests
Scienfitz Nov 1, 2024
34e44bc
Generalized explanation for all shap explainer types, refactoring of …
Alex6022 Nov 18, 2024
9c94ccd
Merge branch 'feature/shap-utils' of https://github.com/Alex6022/bayb…
Alex6022 Nov 18, 2024
7823e87
Reworked tests from feedback, Cleanup for Review
Alex6022 Nov 19, 2024
0c2deb6
Further cleanup
Alex6022 Nov 20, 2024
436c08f
Renaming of "diagnostics" package into "insights", Addition of Insigh…
Alex6022 Dec 9, 2024
3fef7b3
Moved explainer maps from testing of shap functionality to shap, allo…
Alex6022 Dec 10, 2024
b50cd24
Merge branch 'main' of https://github.com/emdgroup/baybe into feature…
Alex6022 Dec 10, 2024
949629f
Cleanup of CONTRIBUTORS.md
Alex6022 Dec 10, 2024
36cb30c
Minor reformatting
Scienfitz Dec 27, 2024
67fdf8d
Package housekeeping
Scienfitz Dec 27, 2024
63da2b7
Merge branch 'main' into feature/shap-utils
Scienfitz Dec 27, 2024
69cecd8
Rework classes
Scienfitz Jan 2, 2025
a6eef09
Update tests
Scienfitz Jan 2, 2025
b93a376
Enhance error message
Scienfitz Jan 3, 2025
271645f
Add special handling for Lime
Scienfitz Jan 3, 2025
5997095
Rename bg_data to background_data
AdrianSosic Jan 6, 2025
ec850ba
Add missing validators
AdrianSosic Jan 6, 2025
ae33f55
Fix type annotations
AdrianSosic Jan 6, 2025
c1d69a6
Replace unnecessary post init call
AdrianSosic Jan 6, 2025
e58c04a
Remove unnecessary guard clause
AdrianSosic Jan 6, 2025
f5530bc
Turn method guard clause into proper attribute validator
AdrianSosic Jan 6, 2025
313f394
Improve model function definition
AdrianSosic Jan 6, 2025
e3bb895
Refactor plot methods to return plt.Axes if plots not shown directly.…
Alex6022 Jan 7, 2025
ede033f
Merge branch 'feature/shap-utils' of https://github.com/Alex6022/bayb…
Alex6022 Jan 7, 2025
ff85048
Drop Insights base class
AdrianSosic Jan 7, 2025
dd83e99
Turn explanation property into explain method
AdrianSosic Jan 7, 2025
740a4a1
Pass data to be explained as method argument
AdrianSosic Jan 7, 2025
09ee6ac
Extract explainer factory function
AdrianSosic Jan 7, 2025
c8940c0
Add from_surrogate constructor
AdrianSosic Jan 7, 2025
86c99d3
Refactor class attributes
AdrianSosic Jan 7, 2025
52db71b
Drop duplicate input validation
AdrianSosic Jan 7, 2025
b03d3bf
Refactor plotting
AdrianSosic Jan 8, 2025
d4aeaf4
Define default explainer class
AdrianSosic Jan 8, 2025
331746d
Make data to be explained optional
AdrianSosic Jan 8, 2025
2dc0e94
Drop converter utility
AdrianSosic Jan 8, 2025
31d95f4
Refactor explainer sets
AdrianSosic Jan 9, 2025
84b5f18
Add column permutation test
AdrianSosic Jan 9, 2025
151207b
Make test pass by permuting columns
AdrianSosic Jan 9, 2025
8750a0f
Rework docstrings
AdrianSosic Jan 9, 2025
b748a57
Avoid unnecessary data copy
AdrianSosic Jan 9, 2025
6e81c97
Fix optional dependency handling
AdrianSosic Jan 9, 2025
4ab5898
Use shap's optional dependency group
AdrianSosic Jan 9, 2025
c819d97
Update lockfile
AdrianSosic Jan 9, 2025
3296933
Update README.md
AdrianSosic Jan 9, 2025
9650397
Refactor insights/__init__.py
AdrianSosic Jan 9, 2025
1d38075
Update README.md
AdrianSosic Jan 9, 2025
bec5fc5
Refactor insights/__init__.py
AdrianSosic Jan 9, 2025
2f5c96e
Add missing sphinx-paramlinks doc dependency
AdrianSosic Jan 9, 2025
b65731b
Update lockfile
AdrianSosic Jan 9, 2025
118de87
Fix BayBE spelling
AdrianSosic Jan 9, 2025
a1b058c
fixup! Rework docstrings
AdrianSosic Jan 10, 2025
11e01f4
Fix mypy issues
AdrianSosic Jan 10, 2025
905e718
Merge branch 'feature/shap-utils' of https://github.com/Alex6022/bayb…
Alex6022 Jan 12, 2025
a9a0dd1
Fixed permutation of explanation object. Reintroduced ValueError when…
Alex6022 Jan 12, 2025
c6421c5
Simplify set union
AdrianSosic Jan 13, 2025
a447c42
Rename df method argument to data
AdrianSosic Jan 13, 2025
b44f47f
Removed double docstring
Alex6022 Jan 13, 2025
03c95f7
Test cleanup
Alex6022 Jan 13, 2025
e075b9e
Improved comments and docstring
Alex6022 Jan 13, 2025
bf9cf2b
Filter data to measurement parameters only when initializing from rec…
Alex6022 Jan 13, 2025
f89f9ef
Validate that explainer is of accepted type
AdrianSosic Jan 14, 2025
8282b9a
Avoid test collection import errors due to optional dependencies
AdrianSosic Jan 14, 2025
856d4fd
Drop try-except branch for unsupported explainers types
AdrianSosic Jan 14, 2025
647ea72
Refactor explainer incompatibility handling using type validation
AdrianSosic Jan 14, 2025
5926496
Move content of temporary test file by overriding fixtures locally
AdrianSosic Jan 14, 2025
432942f
Fix type check
AdrianSosic Jan 14, 2025
06fac92
Replace xfail with skip
AdrianSosic Jan 14, 2025
6010004
Refine permutation workaround
AdrianSosic Jan 14, 2025
56d2b07
Update lockfile
AdrianSosic Jan 14, 2025
cb86db4
Fix sphinx references
AdrianSosic Jan 14, 2025
cc41592
Improve docstrings
AdrianSosic Jan 14, 2025
ba7202d
Reintroduce recommender validation guard clause
AdrianSosic Jan 14, 2025
1077904
Make use_comp_rep flag keyword-only
AdrianSosic Jan 14, 2025
6a5a4d9
Adjust batteries included text in README
AdrianSosic Jan 14, 2025
0f8c6f3
Move import statement to avoid test fail
AdrianSosic Jan 14, 2025
4b8842d
Rephrase bullet point
Scienfitz Jan 20, 2025
b0879c2
Enable force plot
Scienfitz Jan 20, 2025
2a8623b
Fix kwargs type hints
Scienfitz Jan 20, 2025
8a27738
Apply suggestions from code review
Scienfitz Jan 20, 2025
0baf9b7
Expand tests
Scienfitz Jan 20, 2025
8797dab
Use context for expected failures
Scienfitz Jan 20, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]
### Added
- Added SHAP analysis within the new `diagnostics` package.
- `allow_missing` and `allow_extra` keyword arguments to `Objective.transform`

### Deprecations
4 changes: 3 additions & 1 deletion CONTRIBUTORS.md
@@ -25,4 +25,6 @@
- Di Jin (Merck Life Science KGaA, Darmstadt, Germany):\
Cardinality constraints
- Julian Streibel (Merck Life Science KGaA, Darmstadt, Germany):\
Bernoulli multi-armed bandit and Thompson sampling
Bernoulli multi-armed bandit and Thompson sampling
- Alexander Wieczorek (Swiss Federal Institute for Materials Science and Technology, Dübendorf, Switzerland):\
SHAP explainers for diagnostics
1 change: 1 addition & 0 deletions README.md
@@ -296,6 +296,7 @@ The available groups are:
- `mypy`: Required for static type checking.
- `onnx`: Required for using custom surrogate models in [ONNX format](https://onnx.ai).
- `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/)
- `diagnostics`: Required for feature importance ranking via [SHAP](https://shap.readthedocs.io/)
- `simulation`: Enabling the [simulation](https://emdgroup.github.io/baybe/stable/_autosummary/baybe.simulation.html) module.
- `test`: Required for running the tests.
- `dev`: All of the above plus `tox` and `pip-audit`. For code contributors.
16 changes: 16 additions & 0 deletions baybe/_optional/diagnostics.py
@@ -0,0 +1,16 @@
"""Optional import for diagnostics utilities."""

from baybe.exceptions import OptionalImportError

try:
    import shap
except ModuleNotFoundError as ex:
    raise OptionalImportError(
        "Explainer functionality is unavailable because 'diagnostics' is not installed."
        " Consider installing BayBE with 'diagnostics' dependency, e.g. via "
        "`pip install baybe[diagnostics]`."
    ) from ex

__all__ = [
    "shap",
]
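
The guarded import above can be generalized into a small helper. The following is only a sketch of that pattern, not code from this PR: `import_optional` is a hypothetical name, and the local `OptionalImportError` is a stand-in for `baybe.exceptions.OptionalImportError` (assumed here to subclass `ImportError`).

```python
import importlib


class OptionalImportError(ImportError):
    """Stand-in for `baybe.exceptions.OptionalImportError` (assumption)."""


def import_optional(module_name: str, group: str):
    """Import `module_name`, or raise an error naming the dependency group."""
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError as ex:
        # Chain the original exception so the root cause stays visible.
        raise OptionalImportError(
            f"Functionality is unavailable because '{group}' is not installed. "
            f"Consider installing it, e.g. via `pip install baybe[{group}]`."
        ) from ex


# A module that exists is returned as usual.
json_module = import_optional("json", "diagnostics")
```

Raising at import time means any module that depends on the optional package fails early with an actionable message instead of a bare `ModuleNotFoundError`.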
2 changes: 2 additions & 0 deletions baybe/_optional/info.py
@@ -28,6 +28,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404
MORDRED_INSTALLED = find_spec("mordred") is not None
ONNX_INSTALLED = find_spec("onnxruntime") is not None
POLARS_INSTALLED = find_spec("polars") is not None
SHAP_INSTALLED = find_spec("shap") is not None
PRE_COMMIT_INSTALLED = find_spec("pre_commit") is not None
PYDOCLINT_INSTALLED = find_spec("pydoclint") is not None
RDKIT_INSTALLED = find_spec("rdkit") is not None
@@ -45,6 +46,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404

# Package combinations
CHEM_INSTALLED = MORDRED_INSTALLED and RDKIT_INSTALLED
DIAGNOSTICS_INSTALLED = SHAP_INSTALLED
LINT_INSTALLED = all(
    (
        FLAKE8_INSTALLED,
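
The availability flags in this file rely on `importlib.util.find_spec`, which locates a top-level module without importing it. A minimal sketch of how such a flag behaves (the helper name `is_installed` is hypothetical and not part of the PR):

```python
from importlib.util import find_spec


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located, without importing it."""
    return find_spec(module_name) is not None


# Mirrors the flags added in this hunk: the group flag is derived from
# the availability of its single required package.
SHAP_INSTALLED = is_installed("shap")
DIAGNOSTICS_INSTALLED = SHAP_INSTALLED
```

Such flags are typically used to skip tests or hide features when the optional package is absent, while the guarded import in `baybe/_optional/diagnostics.py` handles the case where the code is actually imported.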
194 changes: 194 additions & 0 deletions baybe/utils/diagnostics.py
@@ -0,0 +1,194 @@
"""Diagnostics utilities."""

import numbers
import warnings

import numpy as np
import pandas as pd

from baybe import Campaign
from baybe._optional.diagnostics import shap
from baybe.utils.dataframe import to_tensor


def explainer(
    campaign: Campaign,
    explainer_class: shap.Explainer = shap.KernelExplainer,
    computational_representation: bool = False,
    **kwargs,
) -> shap.Explainer:
    """Create an explainer for the provided campaign.

    Args:
        campaign: The campaign to be explained.
        explainer_class: The explainer to be used. Default is shap.KernelExplainer.
        computational_representation: Whether to compute the Shapley values
            in computational or experimental searchspace.
            Default is False.
        **kwargs: Additional keyword arguments to be passed to the explainer.

    Returns:
        The explainer for the provided campaign.

    Raises:
        ValueError: If no measurements have been provided yet.
    """
    if campaign.measurements.empty:
        raise ValueError("No measurements have been provided yet.")

    data = campaign.measurements[[p.name for p in campaign.parameters]].copy()

    if computational_representation:
        data = campaign.searchspace.transform(data)

        def model(x):
            tensor = to_tensor(x)
            output = campaign.get_surrogate()._posterior_comp(tensor).mean

            return output.detach().numpy()
    else:

        def model(x):
            df = pd.DataFrame(x, columns=data.columns)
            output = campaign.get_surrogate().posterior(df).mean

            return output.detach().numpy()

    shap_explainer = explainer_class(model, data, **kwargs)
    return shap_explainer


def explanation(
    campaign: Campaign,
    data: np.ndarray = None,
    explainer_class: shap.Explainer = shap.KernelExplainer,
    computational_representation: bool = False,
    **kwargs,
) -> shap.Explanation:
    """Compute the Shapley values for the provided campaign and data.

    Args:
        campaign: The campaign to be explained.
        data: The data to be explained.
            Default is None which uses the campaign's measurements.
        explainer_class: The explainer to be used.
            Default is shap.KernelExplainer.
        computational_representation: Whether to compute the Shapley values
            in computational or experimental searchspace.
            Default is False.
        **kwargs: Additional keyword arguments to be passed to the explainer.

    Returns:
        The Shapley values for the provided campaign.

    Raises:
        ValueError: If the provided data does not have the same amount of parameters
            as previously provided to the explainer.
    """
    is_shap_explainer = not explainer_class.__module__.startswith(
        "shap.explainers.other."
    )

    if not is_shap_explainer and not computational_representation:
        raise ValueError(
            "Experimental representation is not "
            "supported for non-Kernel SHAP explainer."
        )

    explainer_obj = explainer(
        campaign,
        explainer_class=explainer_class,
        computational_representation=computational_representation,
        **kwargs,
    )

    if data is None:
        if isinstance(explainer_obj.data, np.ndarray):
            data = explainer_obj.data
        else:
            data = explainer_obj.data.data
    elif computational_representation:
        data = campaign.searchspace.transform(data)

    if not is_shap_explainer:
        """Return attributions for non-SHAP explainers."""
        if explainer_class.__module__.endswith("maple"):
            """Additional argument for maple to increase comparability to SHAP."""
            attributions = explainer_obj.attributions(data, multiply_by_input=True)[0]
        else:
            attributions = explainer_obj.attributions(data)[0]
        if computational_representation:
            feature_names = campaign.searchspace.comp_rep_columns
        else:
            feature_names = campaign.searchspace.parameter_names
        explanations = shap.Explanation(
            values=attributions,
            base_values=data,
            data=data,
        )
        explanations.feature_names = list(feature_names)
        return explanations

    if data.shape[1] != explainer_obj.data.data.shape[1]:
        raise ValueError(
            "The provided data does not have the same amount "
            "of parameters as the shap explainer background."
        )
    else:
        shap_explanations = explainer_obj(data)[:, :, 0]

    return shap_explanations


def plot_beeswarm(explanation: shap.Explanation, **kwargs) -> None:
    """Plot the Shapley values using a beeswarm plot."""
    shap.plots.beeswarm(explanation, **kwargs)


def plot_waterfall(explanation: shap.Explanation, **kwargs) -> None:
    """Plot the Shapley values using a waterfall plot."""
    shap.plots.waterfall(explanation, **kwargs)


def plot_bar(explanation: shap.Explanation, **kwargs) -> None:
    """Plot the Shapley values using a bar plot."""
    shap.plots.bar(explanation, **kwargs)


def plot_scatter(explanation: shap.Explanation | memoryview, **kwargs) -> None:
    """Plot the Shapley values using a scatter plot while leaving out string values.

    Args:
        explanation: The Shapley values to be plotted.
        **kwargs: Additional keyword arguments to be passed to the scatter plot.

    Raises:
        ValueError: If the provided explanation object does not match the
            required types.
    """
    if isinstance(explanation, memoryview):
        data = explanation.obj
    elif isinstance(explanation, shap.Explanation):
        data = explanation.data.data.obj
    else:
        raise ValueError("The provided explanation argument is not of a valid type.")

    def is_not_numeric_column(col):
        return np.array([not isinstance(v, numbers.Number) for v in col]).any()

    if data.ndim == 1:
        if is_not_numeric_column(data):
            warnings.warn(
                "Cannot plot scatter plot for the provided "
                "explanation as it contains non-numeric values."
            )
        else:
            shap.plots.scatter(explanation, **kwargs)
    else:
        number_enum = [i for i, x in enumerate(data[1]) if not isinstance(x, str)]
        if len(number_enum) < len(explanation.feature_names):
            warnings.warn(
                "Cannot plot SHAP scatter plot for all "
                "parameters as some contain non-numeric values."
            )
        shap.plots.scatter(explanation[:, number_enum], **kwargs)
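
The column filtering in `plot_scatter` can be illustrated in isolation. The sketch below uses plain Python lists instead of SHAP objects, and the helper names `non_string_indices` and `has_non_numeric` are hypothetical; it only mirrors the two checks used above (non-string entries of one row, and a `numbers.Number` test over a whole column).

```python
import numbers


def non_string_indices(row):
    """Positions in `row` whose value is not a string (cf. `number_enum`)."""
    return [i for i, value in enumerate(row) if not isinstance(value, str)]


def has_non_numeric(column):
    """True if any entry of `column` is not a number (cf. `is_not_numeric_column`)."""
    return any(not isinstance(value, numbers.Number) for value in column)


# One measurement row mixing numerical and categorical parameter values.
row = [1.5, "solvent_A", 20, "base_B"]
numeric_positions = non_string_indices(row)  # [0, 2]
```

Note that the two checks are not equivalent: the first excludes only strings, while the second rejects every value that is not a `numbers.Number`.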
3 changes: 3 additions & 0 deletions mypy.ini
@@ -64,3 +64,6 @@ ignore_missing_imports = True

[mypy-polars]
ignore_missing_imports = True

[mypy-shap.*]
ignore_missing_imports = True
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -81,6 +81,7 @@ onnx = [

dev = [
    "baybe[chem]",
    "baybe[diagnostics]",
    "baybe[docs]",
    "baybe[examples]",
    "baybe[lint]",
@@ -94,6 +95,11 @@ dev = [
    "uv>=0.3.0", # `uv lock` (for lockfiles) is stable since 0.3.0: https://github.com/astral-sh/uv/issues/2679#event-13950215962
]

diagnostics = [
    "shap>=0.46.0",
    "lime>=0.2.0.1"
]

docs = [
    "baybe[examples]", # docs cannot be built without running examples
    "furo>=2023.09.10",
2 changes: 2 additions & 0 deletions pytest.ini
@@ -10,6 +10,8 @@ addopts =
    --ignore=baybe/_optional
    --ignore=baybe/utils/chemistry.py
    --ignore=tests/simulate_telemetry.py
    --ignore=baybe/utils/diagnostics.py
    --ignore=tests/utils/test_diagnostics.py
testpaths =
    baybe
    tests