Adding CAPTUM and interpretability metrics #675

Draft: wants to merge 28 commits into master
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -65,11 +65,11 @@ jobs:
name: Evaluator tests
when: always
command: |
coverage run -m unittest -v tests/test_captum.py
coverage run -m unittest -v tests/test_deployment.py
coverage run -m unittest -v tests/test_evaluation.py
coverage run -m unittest -v tests/validation_loop_test.py
coverage run -m unittest -v tests/test_handle_multi_crit.py

- run:
name: upload
when: always
2 changes: 1 addition & 1 deletion flood_forecast/evaluator.py
@@ -175,7 +175,7 @@ def evaluate_model(
idx += 1
eval_log[target + "_" + evaluation_metric.__class__.__name__] = s

# Explain model behaviour using shap
# Explain the model behaviour using shap
if "probabilistic" in inference_params:
print("Probabilistic explainability currently not supported.")
elif "n_targets" in model.params:
51 changes: 51 additions & 0 deletions flood_forecast/interpretability.py
@@ -0,0 +1,51 @@
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation
from typing import Tuple, Dict

# Maps config-friendly names to Captum attribution classes. Note that Captum's
# NoiseTunnel normally wraps another attribution method rather than the raw model,
# so constructing it via attr_dict[method](model) may need special-casing.
attr_dict = {"IntegratedGradients": IntegratedGradients, "DeepLift": DeepLift, "GradientSHAP": GradientShap,
             "NoiseTunnel": NoiseTunnel, "FeatureAblation": FeatureAblation}


def run_attribution(model, test_loader, method, additional_params: Dict) -> Tuple:
    """Creates attributions for a model using the Captum library.

    :param model: The deep learning model to be used for attribution. This should be a PyTorch model.
    :type model: torch.nn.Module
    :param test_loader: Should be a FF CSVDataLoader or a related sub-class.
    :type test_loader: CSVDataLoader
    :param method: The name of the attribution method to use; must be a key in attr_dict
        (e.g. "IntegratedGradients").
    :type method: str
    :param additional_params: Extra keyword arguments forwarded to Captum's attribute() call
        (e.g. {"return_convergence_delta": True}).
    :type additional_params: Dict
    :return: A tuple of (attributions, approximation_error); the error is None for methods
        that do not return a convergence delta.
    :rtype: Tuple

    .. code-block:: python

        from flood_forecast.interpretability import run_attribution
        model = VanillaGRU(3, 128, 2, 1, 0.2)
        # Assuming test_loader is a CSVDataLoader over data with three relevant columns.
        attributions, delta = run_attribution(model, test_loader, "IntegratedGradients",
                                              {"return_convergence_delta": True})

    ..
    """

    attribution_method = attr_dict[method](model)
    # Take a single (x, y) window from the loader and add a batch dimension.
    x, y = test_loader[0]
    the_data = attribution_method.attribute(x.unsqueeze(0), **additional_params)
    # Some methods (e.g. IntegratedGradients with return_convergence_delta=True) return
    # an (attributions, delta) tuple; others return only the attribution tensor.
    if isinstance(the_data, tuple):
        attributions, approximation_error = the_data
    else:
        attributions = the_data
        approximation_error = None
    return attributions, approximation_error


def make_attribution_plots(model, test_loader, methods, use_wandb: bool = True):
    """Creates the attribution plots and logs them to wandb if use_wandb is True.

    :param model: The PyTorch model to run the attribution methods on.
    :type model: torch.nn.Module
    :param test_loader: Should be a FF CSVDataLoader or a related sub-class.
    :type test_loader: CSVDataLoader
    :param methods: A list of attribution method names; each must be a key in attr_dict.
    :type methods: list
    :param use_wandb: Whether to log the plots to Weights & Biases, defaults to True
    :type use_wandb: bool, optional
    """
    for method in methods:
        attributions, approx = run_attribution(model, test_loader, method, {})
        # Plotting of the attributions for each method still needs to be implemented here.
        pass
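As a sketch of what the stubbed plotting step above might eventually do, the helper below draws a per-feature bar chart of mean absolute attributions and optionally logs it to wandb. It assumes matplotlib is available and that attributions has shape (batch, forecast_history, n_features); the helper name plot_attributions and the chart layout are illustrative, not part of this PR.

    import matplotlib.pyplot as plt
    import wandb

    def plot_attributions(attributions, feature_names, method, use_wandb=True):
        # Collapse the batch and time axes to get one importance score per feature.
        scores = attributions.abs().mean(dim=(0, 1)).detach().numpy()
        fig, ax = plt.subplots()
        ax.bar(feature_names, scores)
        ax.set_title(method)
        ax.set_ylabel("Mean |attribution|")
        if use_wandb:
            wandb.log({method + "_attributions": wandb.Image(fig)})
        return fig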
8 changes: 4 additions & 4 deletions flood_forecast/transformer_xl/cross_former.py
@@ -24,7 +24,7 @@ def __init__(
"""Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting.
https://github.com/Thinklab-SJTU/Crossformer

:param n_time_series: The total number of time series
:param n_time_series: The total number of time series passed to the model
:type n_time_series: int
:param forecast_history: The length of the input sequence
:type forecast_history: int
@@ -36,11 +36,11 @@
:type win_size: int, optional
:param factor: _description_, defaults to 10
:type factor: int, optional
:param d_model: _description_, defaults to 512
:param d_model: The embedding dimension used across the model, defaults to 512
:type d_model: int, optional
:param d_ff: _description_, defaults to 1024
:type d_ff: int, optional
:param n_heads: _description_, defaults to 8
:param n_heads: The number of heads, defaults to 8
:type n_heads: int, optional
:param e_layers: _description_, defaults to 3
:type e_layers: int, optional
@@ -137,7 +137,7 @@ class SegMerging(nn.Module):
we set win_size = 2 in our paper
"""

def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
def __init__(self, d_model: int, win_size, norm_layer=nn.LayerNorm):
super().__init__()
self.d_model = d_model
self.win_size = win_size
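For readers unfamiliar with Crossformer's segment merging, the snippet below illustrates the shape bookkeeping: with win_size = 2, adjacent segments are concatenated along the feature axis and projected back to d_model, halving the segment count. The tensor sizes are made up for illustration, and the forward pass is paraphrased from the upstream Crossformer code rather than copied from this PR.

    import torch
    import torch.nn as nn

    # Hypothetical input: batch=2, 3 time series, 8 segments, d_model=512.
    x = torch.randn(2, 3, 8, 512)
    win_size, d_model = 2, 512
    # Group every win_size-th segment and concatenate along the last axis.
    merged = torch.cat([x[:, :, i::win_size, :] for i in range(win_size)], dim=-1)
    # Normalize, then project win_size * d_model back down to d_model.
    out = nn.Linear(win_size * d_model, d_model)(nn.LayerNorm(win_size * d_model)(merged))
    print(out.shape)  # torch.Size([2, 3, 4, 512])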
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,5 +1,6 @@
shap==0.41.0
scikit-learn>=1.0.1
captum
pandas
torch
tb-nightly
2 changes: 1 addition & 1 deletion tests/multi_decoder_test.json
@@ -18,7 +18,7 @@
"forecast_history":5,
"forecast_length":1,
"train_end": 100,
"valid_start":101,
"valid_start":102,
"valid_end": 201,
"test_start": 202,
"test_end": 290,
33 changes: 33 additions & 0 deletions tests/test_captum.py
@@ -0,0 +1,33 @@
import unittest
import torch
import os
from flood_forecast.interpretability import run_attribution, make_attribution_plots
from flood_forecast.basic.gru_vanilla import VanillaGRU
from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader


class TestCaptum(unittest.TestCase):
    def setUp(self):
        # n_time_series: int, hidden_dim: int, num_layers: int, n_target: int, dropout: float
        self.test_model = VanillaGRU(3, 128, 2, 1, 0.2)
        self.test_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data")
        self.test_data_loader = CSVDataLoader(
            os.path.join(self.test_data_path, "keag_small.csv"),
            100,
            20,
            "precip",
            ["precip", "cfs", "temp"]
        )

    def test_run_attribution(self):
        """Tests that IntegratedGradients attribution returns a tensor and a convergence delta."""
        attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients",
                                                     {"return_convergence_delta": True})
        self.assertEqual(approx_error.shape[0], 1)
        self.assertIsInstance(attributions, torch.Tensor)
        # self.assertEqual(attributions.shape[2], 3)

    def test_create_attribution_plots(self):
        """Tests that the attribution plots can be created without wandb logging."""
        make_attribution_plots(self.test_model, self.test_data_loader, ["IntegratedGradients"], use_wandb=False)
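To run the new suite locally, mirroring the CI step added earlier in this PR:

    python -m unittest -v tests/test_captum.py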