From da06f12f533e01041803ab412f3130af3784aeb1 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Fri, 19 May 2023 21:27:43 -0300 Subject: [PATCH 01/26] begin adding files --- flood_forecast/interpretability.py | 13 +++++++++++++ requirements.txt | 1 + 2 files changed, 14 insertions(+) create mode 100644 flood_forecast/interpretability.py diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py new file mode 100644 index 000000000..ac20cd94e --- /dev/null +++ b/flood_forecast/interpretability.py @@ -0,0 +1,13 @@ +from captum.attr import IntegratedGradients + + +def run_attribution(model, test_loader, method): + """ + """ + ig = IntegratedGradients(model) + x, y = test_loader[0] + attributions, approximation_error = ig.attribute((input1, input2), + baselines=(baseline1, baseline2), + method='gausslegendre', + return_convergence_delta=True) + pass diff --git a/requirements.txt b/requirements.txt index 292bc9c46..b25be39e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ shap==0.41.0 scikit-learn>=1.0.1 +captum pandas torch tb-nightly From c2d442e5d081f949111005d11712dc6858e498d5 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Fri, 19 May 2023 21:37:15 -0300 Subject: [PATCH 02/26] fixing code 2 --- flood_forecast/interpretability.py | 10 ++++++---- requirements.txt | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index ac20cd94e..fbbcf646b 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -1,13 +1,15 @@ from captum.attr import IntegratedGradients +from typing import Tuple +attr_dict = {} -def run_attribution(model, test_loader, method): +def run_attribution(model, test_loader, method) -> Tuple: """ """ ig = IntegratedGradients(model) x, y = test_loader[0] - attributions, approximation_error = ig.attribute((input1, input2), - baselines=(baseline1, baseline2), + attributions, approximation_error = ig.attribute(x, + baselines=(baseline1), method='gausslegendre', return_convergence_delta=True) - pass + return attributions, approximation_error diff --git a/requirements.txt b/requirements.txt index b25be39e7..176359afe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ tb-nightly seaborn future h5py -wandb==0.15.1 +wandb==0.15.2 google-cloud google-cloud-storage plotly~=5.14.1 From 0f3ddffa4f72eb4c506bcdb7579c699e93307012 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 24 May 2023 13:48:53 -0300 Subject: [PATCH 03/26] adding fixes --- flood_forecast/interpretability.py | 29 ++++++++++++------- flood_forecast/transformer_xl/cross_former.py | 2 +- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index fbbcf646b..96cc5f978 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -1,15 +1,24 @@ -from captum.attr import IntegratedGradients -from typing import Tuple +from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation +from typing import Tuple, Dict -attr_dict = {} +attr_dict = {"IntegratedGradients": IntegratedGradients, "DeepLift": DeepLift, "GradientSHAP": GradientShap, + "NoiseTunnel": NoiseTunnel, "FeatureAblation": FeatureAblation} -def run_attribution(model, test_loader, method) -> Tuple: - """ + +def run_attribution(model, test_loader, method, additional_params: Dict) -> Tuple: + """Function that creates attribution for a model based on Captum. + + :param model: The deep learning model to be used for attribution. This should be a PyTorch model. + :type model: _type_ + :param test_loader: _description_ + :type test_loader: _type_ + :param method: _description_ + :type method: _type_ + :return: d + :rtype: Tuple """ - ig = IntegratedGradients(model) + + attribution_method = attr_dict[method](model) x, y = test_loader[0] - attributions, approximation_error = ig.attribute(x, - baselines=(baseline1), - method='gausslegendre', - return_convergence_delta=True) + attributions, approximation_error = attribution_method.attribute(x.unsqueeze(0), **additional_params) return attributions, approximation_error diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py index c7aa42d47..bba0923fd 100644 --- a/flood_forecast/transformer_xl/cross_former.py +++ b/flood_forecast/transformer_xl/cross_former.py @@ -36,7 +36,7 @@ def __init__( :type win_size: int, optional :param factor: _description_, defaults to 10 :type factor: int, optional - :param d_model: _description_, defaults to 512 + :param d_model: _description_, sdefaults to 512 :type d_model: int, optional :param d_ff: _description_, defaults to 1024 :type d_ff: int, optional From 726e07737865f3c27a4e4242a25158519dcbd69b Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 24 May 2023 14:26:53 -0300 Subject: [PATCH 04/26] fixing captum 4 5 --- .circleci/config.yml | 1 + tests/test_captum.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 tests/test_captum.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 93fbda8c9..27a9c6ee9 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -65,6 +65,7 @@ jobs: name: Evaluator tests when: always command: | + coverage run -m unitest -v tests/test_captum.py coverage run -m unittest -v tests/test_deployment.py coverage run -m unittest -v tests/test_evaluation.py coverage run -m unittest -v tests/validation_loop_test.py diff --git a/tests/test_captum.py b/tests/test_captum.py new file mode 100644 index 000000000..7c390385c --- /dev/null +++ b/tests/test_captum.py @@ -0,0 +1,24 @@ +import unittest +import torch +from flood_forecast.interpretability import run_attribution, create_attribution_plots +from flood_forecast.basic.gru_vanilla import GRUVanilla +from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader + + +class TestAttention(unittest.TestCase): + def setUp(self): + # n_time_series: int, hidden_dim: int, num_layers: int, n_target: int, dropout: float + self.test_model = GRUVanilla(3, 128, 2, 1, 0.2) + self.test_data_loader = CSVDataLoader( + "tests/data/test_data.csv", + 100, + 20, + "precip", + ["precip", "cfs", "temp"] + ) + + def test_run_attribution(self): + """_summary_""" + attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients", {}) + self.assertEqual(approx_error.shape, torch.Size([1, 20, 3])) + self.assertEqual(attributions.shape, torch.Size([1, 20, 3])) From 127633f90c878b2faae29d1dfef27407a2870201 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 24 May 2023 14:45:49 -0300 Subject: [PATCH 05/26] stfu linter --- flood_forecast/interpretability.py | 13 +++++++++++++ tests/test_captum.py | 5 +++++ 2 files changed, 18 insertions(+) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index 96cc5f978..f7d6c015e 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -22,3 +22,16 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl x, y = test_loader[0] attributions, approximation_error = attribution_method.attribute(x.unsqueeze(0), **additional_params) return attributions, approximation_error + + +def make_attribution_plots(attributions, approximation_error, use_wandb: bool = True): + """_summary_ + + :param attributions: _description_ + :type attributions: _type_ + :param approximation_error: _description_ + :type approximation_error: _type_ + :param use_wandb: _description_, defaults to True + :type use_wandb: bool, optional + """ + pass diff --git a/tests/test_captum.py b/tests/test_captum.py index 7c390385c..8cd448e28 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -22,3 +22,8 @@ def test_run_attribution(self): attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients", {}) self.assertEqual(approx_error.shape, torch.Size([1, 20, 3])) self.assertEqual(attributions.shape, torch.Size([1, 20, 3])) + + def test_create_attribution_plots(self): + """_summary_""" + attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients", {}) + create_attribution_plots(attributions, approx_error, use_wandb=False) \ No newline at end of file From 5e1aebd79c160c60780b7ac8ec3c1e4f07b08601 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 24 May 2023 14:56:43 -0300 Subject: [PATCH 06/26] fixing code 4.5 --- flood_forecast/interpretability.py | 2 +- tests/test_captum.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index f7d6c015e..cb42391f2 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -10,7 +10,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl :param model: The deep learning model to be used for attribution. This should be a PyTorch model. :type model: _type_ - :param test_loader: _description_ + :param test_loader: Should be a FF CSVDataLoader or a related subclass. :type test_loader: _type_ :param method: _description_ :type method: _type_ diff --git a/tests/test_captum.py b/tests/test_captum.py index 8cd448e28..e235a4f41 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -26,4 +26,4 @@ def test_run_attribution(self): def test_create_attribution_plots(self): """_summary_""" attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients", {}) - create_attribution_plots(attributions, approx_error, use_wandb=False) \ No newline at end of file + create_attribution_plots(attributions, approx_error, use_wandb=False) From d12e5c412707917fc8a80345e767ce8658a4103d Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 24 May 2023 15:07:02 -0300 Subject: [PATCH 07/26] sleeplier than sleepy joe --- flood_forecast/transformer_xl/cross_former.py | 2 +- tests/test_captum.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py index bba0923fd..adedfbc84 100644 --- a/flood_forecast/transformer_xl/cross_former.py +++ b/flood_forecast/transformer_xl/cross_former.py @@ -137,7 +137,7 @@ class SegMerging(nn.Module): we set win_size = 2 in our paper """ - def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm): + def __init__(self, d_model: int, win_size, norm_layer=nn.LayerNorm): super().__init__() self.d_model = d_model self.win_size = win_size diff --git a/tests/test_captum.py b/tests/test_captum.py index e235a4f41..617048305 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -5,7 +5,7 @@ from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader -class TestAttention(unittest.TestCase): +class TestCaptum(unittest.TestCase): def setUp(self): # n_time_series: int, hidden_dim: int, num_layers: int, n_target: int, dropout: float self.test_model = GRUVanilla(3, 128, 2, 1, 0.2) From c3d2707e4c2515eff82d669a3be4045c08f0bfde Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 24 May 2023 15:07:14 -0300 Subject: [PATCH 08/26] r --- .circleci/config.yml | 2 +- flood_forecast/transformer_xl/cross_former.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 27a9c6ee9..181c8ce85 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -65,7 +65,7 @@ jobs: name: Evaluator tests when: always command: | - coverage run -m unitest -v tests/test_captum.py + coverage run -m unittest -v tests/test_captum.py coverage run -m unittest -v tests/test_deployment.py coverage run -m unittest -v tests/test_evaluation.py coverage run -m unittest -v tests/validation_loop_test.py diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py index adedfbc84..f1b4f046f 100644 --- a/flood_forecast/transformer_xl/cross_former.py +++ b/flood_forecast/transformer_xl/cross_former.py @@ -144,7 +144,7 @@ def __init__(self, d_model: int, win_size, norm_layer=nn.LayerNorm): self.linear_trans = nn.Linear(win_size * d_model, d_model) self.norm = norm_layer(win_size * d_model) - def forward(self, x): + def forward(self, x: torch.Tensor): """ x: B, ts_d, L, d_model """ From c8370a852fcd7e72c3254903304efaf84060cf90 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 24 May 2023 15:07:46 -0300 Subject: [PATCH 09/26] Revert "r" This reverts commit c3d2707e4c2515eff82d669a3be4045c08f0bfde. --- .circleci/config.yml | 2 +- flood_forecast/transformer_xl/cross_former.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 181c8ce85..27a9c6ee9 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -65,7 +65,7 @@ jobs: name: Evaluator tests when: always command: | - coverage run -m unittest -v tests/test_captum.py + coverage run -m unitest -v tests/test_captum.py coverage run -m unittest -v tests/test_deployment.py coverage run -m unittest -v tests/test_evaluation.py coverage run -m unittest -v tests/validation_loop_test.py diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py index f1b4f046f..adedfbc84 100644 --- a/flood_forecast/transformer_xl/cross_former.py +++ b/flood_forecast/transformer_xl/cross_former.py @@ -144,7 +144,7 @@ def __init__(self, d_model: int, win_size, norm_layer=nn.LayerNorm): self.linear_trans = nn.Linear(win_size * d_model, d_model) self.norm = norm_layer(win_size * d_model) - def forward(self, x: torch.Tensor): + def forward(self, x): """ x: B, ts_d, L, d_model """ From 770f20e208c4f7f81eb9b9f2d11d382aaa2a6cc5 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 24 May 2023 15:39:56 -0300 Subject: [PATCH 10/26] refix config file 2 --- .circleci/config.yml | 3 +-- flood_forecast/transformer_xl/cross_former.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 27a9c6ee9..a1bb37fe3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -65,12 +65,11 @@ jobs: name: Evaluator tests when: always command: | - coverage run -m unitest -v tests/test_captum.py + coverage run -m unittest -v tests/test_captum.py coverage run -m unittest -v tests/test_deployment.py coverage run -m unittest -v tests/test_evaluation.py coverage run -m unittest -v tests/validation_loop_test.py coverage run -m unittest -v tests/test_handle_multi_crit.py - - run: name: upload when: always diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py index adedfbc84..52c344283 100644 --- a/flood_forecast/transformer_xl/cross_former.py +++ b/flood_forecast/transformer_xl/cross_former.py @@ -40,7 +40,7 @@ def __init__( :type d_model: int, optional :param d_ff: _description_, defaults to 1024 :type d_ff: int, optional - :param n_heads: _description_, defaults to 8 + :param n_heads: The number of heads, defaults to 8 :type n_heads: int, optional :param e_layers: _description_, defaults to 3 :type e_layers: int, optional From c4bd4b86b3952dd66fa90ef09c803f575c1d7b79 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 24 May 2023 15:49:04 -0300 Subject: [PATCH 11/26] remove unused function --- flood_forecast/evaluator.py | 2 +- tests/test_captum.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flood_forecast/evaluator.py b/flood_forecast/evaluator.py index 5f7e10a69..ad954c2de 100644 --- a/flood_forecast/evaluator.py +++ b/flood_forecast/evaluator.py @@ -175,7 +175,7 @@ def evaluate_model( idx += 1 eval_log[target + "_" + evaluation_metric.__class__.__name__] = s - # Explain model behaviour using shap + # Explain the model behaviour using shap if "probabilistic" in inference_params: print("Probabilistic explainability currently not supported.") elif "n_targets" in model.params: diff --git a/tests/test_captum.py b/tests/test_captum.py index 617048305..30ffb031e 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -1,6 +1,6 @@ import unittest import torch -from flood_forecast.interpretability import run_attribution, create_attribution_plots +from flood_forecast.interpretability import run_attribution, make_attribution_plots from flood_forecast.basic.gru_vanilla import GRUVanilla from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader @@ -26,4 +26,4 @@ def test_run_attribution(self): def test_create_attribution_plots(self): """_summary_""" attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients", {}) - create_attribution_plots(attributions, approx_error, use_wandb=False) + make_attribution_plots(attributions, approx_error, use_wandb=False) From bd8061f6e9dbc58b2ae120efab3f85a8e7f427a0 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 24 May 2023 17:34:39 -0300 Subject: [PATCH 12/26] fixing captum tests --- flood_forecast/interpretability.py | 9 ++++++--- tests/test_captum.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index cb42391f2..72b7accf6 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -1,5 +1,6 @@ from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation from typing import Tuple, Dict +import numpy as np attr_dict = {"IntegratedGradients": IntegratedGradients, "DeepLift": DeepLift, "GradientSHAP": GradientShap, "NoiseTunnel": NoiseTunnel, "FeatureAblation": FeatureAblation} @@ -14,7 +15,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl :type test_loader: _type_ :param method: _description_ :type method: _type_ - :return: d + :return: Returns a Tuple of attributions and approximation error. This data is used to create plots. :rtype: Tuple """ @@ -24,7 +25,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl return attributions, approximation_error -def make_attribution_plots(attributions, approximation_error, use_wandb: bool = True): +def make_attribution_plots(x, attributions, approximation_error, feature_names, use_wandb: bool = True): """_summary_ :param attributions: _description_ @@ -34,4 +35,6 @@ def make_attribution_plots(attributions, approximation_error, use_wandb: bool = :param use_wandb: _description_, defaults to True :type use_wandb: bool, optional """ - pass + x_axis_data = np.arange(x.shape[1]) + print(x_axis_data) + print("Hello world") diff --git a/tests/test_captum.py b/tests/test_captum.py index 30ffb031e..692d2d98f 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -1,14 +1,14 @@ import unittest import torch from flood_forecast.interpretability import run_attribution, make_attribution_plots -from flood_forecast.basic.gru_vanilla import GRUVanilla +from flood_forecast.basic.gru_vanilla import VanillaGRU from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader class TestCaptum(unittest.TestCase): def setUp(self): # n_time_series: int, hidden_dim: int, num_layers: int, n_target: int, dropout: float - self.test_model = GRUVanilla(3, 128, 2, 1, 0.2) + self.test_model = VanillaGRU(3, 128, 2, 1, 0.2) self.test_data_loader = CSVDataLoader( "tests/data/test_data.csv", 100, From ea3b2f1795b52785cb8b66dc4168bcccacbae511 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Wed, 24 May 2023 18:14:59 -0300 Subject: [PATCH 13/26] Revert "fixing captum tests" This reverts commit bd8061f6e9dbc58b2ae120efab3f85a8e7f427a0. --- flood_forecast/interpretability.py | 9 +++------ tests/test_captum.py | 4 ++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index 72b7accf6..cb42391f2 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -1,6 +1,5 @@ from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation from typing import Tuple, Dict -import numpy as np attr_dict = {"IntegratedGradients": IntegratedGradients, "DeepLift": DeepLift, "GradientSHAP": GradientShap, "NoiseTunnel": NoiseTunnel, "FeatureAblation": FeatureAblation} @@ -15,7 +14,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl :type test_loader: _type_ :param method: _description_ :type method: _type_ - :return: Returns a Tuple of attributions and approximation error. This data is used to create plots. + :return: d :rtype: Tuple """ @@ -25,7 +24,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl return attributions, approximation_error -def make_attribution_plots(x, attributions, approximation_error, feature_names, use_wandb: bool = True): +def make_attribution_plots(attributions, approximation_error, use_wandb: bool = True): """_summary_ :param attributions: _description_ @@ -35,6 +34,4 @@ def make_attribution_plots(x, attributions, approximation_error, feature_names, :param use_wandb: _description_, defaults to True :type use_wandb: bool, optional """ - x_axis_data = np.arange(x.shape[1]) - print(x_axis_data) - print("Hello world") + pass diff --git a/tests/test_captum.py b/tests/test_captum.py index 692d2d98f..30ffb031e 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -1,14 +1,14 @@ import unittest import torch from flood_forecast.interpretability import run_attribution, make_attribution_plots -from flood_forecast.basic.gru_vanilla import VanillaGRU +from flood_forecast.basic.gru_vanilla import GRUVanilla from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader class TestCaptum(unittest.TestCase): def setUp(self): # n_time_series: int, hidden_dim: int, num_layers: int, n_target: int, dropout: float - self.test_model = VanillaGRU(3, 128, 2, 1, 0.2) + self.test_model = GRUVanilla(3, 128, 2, 1, 0.2) self.test_data_loader = CSVDataLoader( "tests/data/test_data.csv", 100, From bd6ba3bae858aa295bd6a39a6919b4ef543c505b Mon Sep 17 00:00:00 2001 From: isaacmg Date: Tue, 30 May 2023 03:47:00 -0300 Subject: [PATCH 14/26] fixing unit tests of captum 2 --- flood_forecast/interpretability.py | 2 +- tests/test_captum.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index cb42391f2..2b31283f3 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -25,7 +25,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl def make_attribution_plots(attributions, approximation_error, use_wandb: bool = True): - """_summary_ + """TODO implemnt :param attributions: _description_ :type attributions: _type_ diff --git a/tests/test_captum.py b/tests/test_captum.py index 30ffb031e..6f88c8f10 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -1,16 +1,18 @@ import unittest import torch +import os from flood_forecast.interpretability import run_attribution, make_attribution_plots -from flood_forecast.basic.gru_vanilla import GRUVanilla +from flood_forecast.basic.gru_vanilla import VanillaGRU from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader class TestCaptum(unittest.TestCase): def setUp(self): # n_time_series: int, hidden_dim: int, num_layers: int, n_target: int, dropout: float - self.test_model = GRUVanilla(3, 128, 2, 1, 0.2) + self.test_model = VanillaGRU(3, 128, 2, 1, 0.2) + self.test_data_path = self.test_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data") self.test_data_loader = CSVDataLoader( - "tests/data/test_data.csv", + os.path.join(self.test_data_path, "keag_small.csv"), 100, 20, "precip", From a13120b2e1d0aa71c10d8caf521b35812477d304 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Tue, 30 May 2023 03:47:44 -0300 Subject: [PATCH 15/26] Revert "fixing unit tests of captum 2" This reverts commit bd6ba3bae858aa295bd6a39a6919b4ef543c505b. --- flood_forecast/interpretability.py | 2 +- tests/test_captum.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index 2b31283f3..cb42391f2 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -25,7 +25,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl def make_attribution_plots(attributions, approximation_error, use_wandb: bool = True): - """TODO implemnt + """_summary_ :param attributions: _description_ :type attributions: _type_ diff --git a/tests/test_captum.py b/tests/test_captum.py index 6f88c8f10..30ffb031e 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -1,18 +1,16 @@ import unittest import torch -import os from flood_forecast.interpretability import run_attribution, make_attribution_plots -from flood_forecast.basic.gru_vanilla import VanillaGRU +from flood_forecast.basic.gru_vanilla import GRUVanilla from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader class TestCaptum(unittest.TestCase): def setUp(self): # n_time_series: int, hidden_dim: int, num_layers: int, n_target: int, dropout: float - self.test_model = VanillaGRU(3, 128, 2, 1, 0.2) - self.test_data_path = self.test_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data") + self.test_model = GRUVanilla(3, 128, 2, 1, 0.2) self.test_data_loader = CSVDataLoader( - os.path.join(self.test_data_path, "keag_small.csv"), + "tests/data/test_data.csv", 100, 20, "precip", From 03b902e4d2efc2f11b5ddd4c1ac892eb4610a1d7 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Thu, 1 Jun 2023 17:14:15 -0300 Subject: [PATCH 16/26] add print debugging 1 --- flood_forecast/interpretability.py | 3 ++- tests/test_captum.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index cb42391f2..b5f6c1de8 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -10,7 +10,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl :param model: The deep learning model to be used for attribution. This should be a PyTorch model. :type model: _type_ - :param test_loader: Should be a FF CSVDataLoader or a related subclass. + :param test_loader: Should be a FF CSVDataLoader or a related sub-class. :type test_loader: _type_ :param method: _description_ :type method: _type_ @@ -20,6 +20,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl attribution_method = attr_dict[method](model) x, y = test_loader[0] + print(attribution_method.attribute(x.unsqueeze(0), **additional_params)) attributions, approximation_error = attribution_method.attribute(x.unsqueeze(0), **additional_params) return attributions, approximation_error diff --git a/tests/test_captum.py b/tests/test_captum.py index 30ffb031e..6f88c8f10 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -1,16 +1,18 @@ import unittest import torch +import os from flood_forecast.interpretability import run_attribution, make_attribution_plots -from flood_forecast.basic.gru_vanilla import GRUVanilla +from flood_forecast.basic.gru_vanilla import VanillaGRU from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader class TestCaptum(unittest.TestCase): def setUp(self): # n_time_series: int, hidden_dim: int, num_layers: int, n_target: int, dropout: float - self.test_model = GRUVanilla(3, 128, 2, 1, 0.2) + self.test_model = VanillaGRU(3, 128, 2, 1, 0.2) + self.test_data_path = self.test_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data") self.test_data_loader = CSVDataLoader( - "tests/data/test_data.csv", + os.path.join(self.test_data_path, "keag_small.csv"), 100, 20, "precip", From 48925f58f0fc259555b7e79671021b678d1d946f Mon Sep 17 00:00:00 2001 From: isaacmg Date: Thu, 1 Jun 2023 17:14:54 -0300 Subject: [PATCH 17/26] Revert "add print debugging 1" This reverts commit 03b902e4d2efc2f11b5ddd4c1ac892eb4610a1d7. --- flood_forecast/interpretability.py | 3 +-- tests/test_captum.py | 8 +++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index b5f6c1de8..cb42391f2 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -10,7 +10,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl :param model: The deep learning model to be used for attribution. This should be a PyTorch model. :type model: _type_ - :param test_loader: Should be a FF CSVDataLoader or a related sub-class. + :param test_loader: Should be a FF CSVDataLoader or a related subclass. :type test_loader: _type_ :param method: _description_ :type method: _type_ @@ -20,7 +20,6 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl attribution_method = attr_dict[method](model) x, y = test_loader[0] - print(attribution_method.attribute(x.unsqueeze(0), **additional_params)) attributions, approximation_error = attribution_method.attribute(x.unsqueeze(0), **additional_params) return attributions, approximation_error diff --git a/tests/test_captum.py b/tests/test_captum.py index 6f88c8f10..30ffb031e 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -1,18 +1,16 @@ import unittest import torch -import os from flood_forecast.interpretability import run_attribution, make_attribution_plots -from flood_forecast.basic.gru_vanilla import VanillaGRU +from flood_forecast.basic.gru_vanilla import GRUVanilla from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader class TestCaptum(unittest.TestCase): def setUp(self): # n_time_series: int, hidden_dim: int, num_layers: int, n_target: int, dropout: float - self.test_model = VanillaGRU(3, 128, 2, 1, 0.2) - self.test_data_path = self.test_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data") + self.test_model = GRUVanilla(3, 128, 2, 1, 0.2) self.test_data_loader = CSVDataLoader( - os.path.join(self.test_data_path, "keag_small.csv"), + "tests/data/test_data.csv", 100, 20, "precip", From 25c5713ea3dff1c7d5688de852aca711621ebf98 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Thu, 1 Jun 2023 17:46:01 -0300 Subject: [PATCH 18/26] fixing code 3 --- flood_forecast/interpretability.py | 14 +++++++++++--- tests/test_captum.py | 3 ++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index cb42391f2..44f812d38 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -7,20 +7,28 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tuple: """Function that creates attribution for a model based on Captum. - :param model: The deep learning model to be used for attribution. This should be a PyTorch model. :type model: _type_ - :param test_loader: Should be a FF CSVDataLoader or a related subclass. + :param test_loader: Should be a FF CSVDataLoader or a related sub-class. :type test_loader: _type_ :param method: _description_ :type method: _type_ :return: d :rtype: Tuple + + .. code-block:: python + + .. """ attribution_method = attr_dict[method](model) x, y = test_loader[0] - attributions, approximation_error = attribution_method.attribute(x.unsqueeze(0), **additional_params) + the_data = attribution_method.attribute(x.unsqueeze(0), **additional_params) + if isinstance(the_data, tuple): + attributions, approximation_error = the_data + else: + attributions = the_data + approximation_error = None return attributions, approximation_error diff --git a/tests/test_captum.py b/tests/test_captum.py index 30ffb031e..82ee25e69 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -19,7 +19,8 @@ def setUp(self): def test_run_attribution(self): """_summary_""" - attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients", {}) + attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients", + {"return_convergence_delta": True}) self.assertEqual(approx_error.shape, torch.Size([1, 20, 3])) self.assertEqual(attributions.shape, torch.Size([1, 20, 3])) From 2f4ad13aa1033a4f7e24f26ad153bf713beb4360 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Thu, 1 Jun 2023 18:12:38 -0300 Subject: [PATCH 19/26] adding code 2 --- flood_forecast/transformer_xl/cross_former.py | 2 +- tests/test_captum.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py index 52c344283..8f67b0ccd 100644 --- a/flood_forecast/transformer_xl/cross_former.py +++ b/flood_forecast/transformer_xl/cross_former.py @@ -24,7 +24,7 @@ def __init__( """Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting. https://github.com/Thinklab-SJTU/Crossformer - :param n_time_series: The total number of time series + :param n_time_series: The total number of time series passed to the model :type n_time_series: int :param forecast_history: The length of the input sequence :type forecast_history: int diff --git a/tests/test_captum.py b/tests/test_captum.py index 82ee25e69..a9cd7a4e7 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -1,7 +1,7 @@ import unittest import torch from flood_forecast.interpretability import run_attribution, make_attribution_plots -from flood_forecast.basic.gru_vanilla import GRUVanilla +from flood_forecast.basic.gru_vanilla import VanillaGRU as GRUVanilla from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader From b3f66c2e370c8f103aa86f267dfec4caa0099fc8 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Thu, 1 Jun 2023 18:22:32 -0300 Subject: [PATCH 20/26] fixing code dwdgl,;msfdgbml,ksd --- flood_forecast/interpretability.py | 9 ++++++--- tests/test_captum.py | 8 +++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index 44f812d38..121aae4e9 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -6,9 +6,9 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tuple: - """Function that creates attribution for a model based on Captum. + """Function that creates attribution for a model based on Captum library. :param model: The deep learning model to be used for attribution. This should be a PyTorch model. - :type model: _type_ + :type model: torch.nn.Module :param test_loader: Should be a FF CSVDataLoader or a related sub-class. :type test_loader: _type_ :param method: _description_ @@ -18,6 +18,9 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl .. code-block:: python + from flood_forecast.interpretability import run_attribution + model = VanillaGRU(3, 128, 2, 1, 0.2) + .. """ @@ -33,7 +36,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl def make_attribution_plots(attributions, approximation_error, use_wandb: bool = True): - """_summary_ + """Creates the attribution plots and logs them to wandb if use_wandb is True. :param attributions: _description_ :type attributions: _type_ diff --git a/tests/test_captum.py b/tests/test_captum.py index a9cd7a4e7..0cfe57b56 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -1,16 +1,18 @@ import unittest import torch +import os from flood_forecast.interpretability import run_attribution, make_attribution_plots -from flood_forecast.basic.gru_vanilla import VanillaGRU as GRUVanilla +from flood_forecast.basic.gru_vanilla import VanillaGRU from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader class TestCaptum(unittest.TestCase): def setUp(self): # n_time_series: int, hidden_dim: int, num_layers: int, n_target: int, dropout: float - self.test_model = GRUVanilla(3, 128, 2, 1, 0.2) + self.test_model = VanillaGRU(3, 128, 2, 1, 0.2) + self.test_data_path = self.test_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data") self.test_data_loader = CSVDataLoader( - "tests/data/test_data.csv", + os.path.join(self.test_data_path, "keag_small.csv"), 100, 20, "precip", From eaea42e45cffe00f6a3a9df3f74cfb95d10d1979 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Thu, 1 Jun 2023 20:27:04 -0300 Subject: [PATCH 21/26] r --- tests/multi_decoder_test.json | 2 +- tests/test_captum.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/multi_decoder_test.json b/tests/multi_decoder_test.json index 65d01cc78..26588dc39 100644 --- a/tests/multi_decoder_test.json +++ b/tests/multi_decoder_test.json @@ -18,7 +18,7 @@ "forecast_history":5, "forecast_length":1, "train_end": 100, - "valid_start":101, + "valid_start":102, "valid_end": 201, "test_start": 202, "test_end": 290, diff --git a/tests/test_captum.py b/tests/test_captum.py index 0cfe57b56..da25de9d5 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -23,7 +23,8 @@ def test_run_attribution(self): """_summary_""" attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients", {"return_convergence_delta": True}) - self.assertEqual(approx_error.shape, torch.Size([1, 20, 3])) + print(attributions.shape, approx_error.shape) + self.assertEqual(approx_error.shape[0], 1) self.assertEqual(attributions.shape, torch.Size([1, 20, 3])) def test_create_attribution_plots(self): From 0fdf9b5364a3e569e3ac89bf5050f4c33217a868 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Tue, 6 Jun 2023 15:28:16 -0300 Subject: [PATCH 22/26] fixng tests and more 2 --- flood_forecast/interpretability.py | 4 ++-- tests/test_captum.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index 121aae4e9..129f5a91e 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -38,8 +38,8 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl def make_attribution_plots(attributions, approximation_error, use_wandb: bool = True): """Creates the attribution plots and logs them to wandb if use_wandb is True. - :param attributions: _description_ - :type attributions: _type_ + :param attributions: A tensor of the attributions should be of dimension (batch_size, , n_features). + :type attributions: torch.Tensor :param approximation_error: _description_ :type approximation_error: _type_ :param use_wandb: _description_, defaults to True diff --git a/tests/test_captum.py b/tests/test_captum.py index da25de9d5..7846be8b0 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -23,9 +23,9 @@ def test_run_attribution(self): """_summary_""" attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients", {"return_convergence_delta": True}) - print(attributions.shape, approx_error.shape) self.assertEqual(approx_error.shape[0], 1) - self.assertEqual(attributions.shape, torch.Size([1, 20, 3])) + self.assertIsInstance(attributions, torch.Tensor) + # self.assertEqual(attributions.shape[2], 3) def test_create_attribution_plots(self): """_summary_""" From 5574ef78394564b298e753deef838d7aea0c47b1 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Fri, 9 Jun 2023 22:08:46 -0400 Subject: [PATCH 23/26] r --- tests/test_captum.py | 3 +++ tests/test_dual.json | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_captum.py b/tests/test_captum.py index 7846be8b0..78e316283 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -4,6 +4,7 @@ from flood_forecast.interpretability import run_attribution, make_attribution_plots from flood_forecast.basic.gru_vanilla import VanillaGRU from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader +from captum.attr import IntegratedGradients class TestCaptum(unittest.TestCase): @@ -29,5 +30,7 @@ def test_run_attribution(self): def test_create_attribution_plots(self): """_summary_""" + attr = IntegratedGradients(self.test_model) + attr.attribute(self.test_data_loader) attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients", {}) make_attribution_plots(attributions, approx_error, use_wandb=False) diff --git a/tests/test_dual.json b/tests/test_dual.json index 850632927..a52b07afd 100644 --- a/tests/test_dual.json +++ b/tests/test_dual.json @@ -20,7 +20,7 @@ "train_end": 100, "valid_start":101, "valid_end": 201, - "test_start": 202, + "test_start": 232, "test_end": 290, "target_col": ["cfs"], "relevant_cols": ["cfs", "precip", "temp"], From b5e68c70bd2779516c58ee292fd60d3fd68a38e5 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Mon, 12 Jun 2023 01:38:57 -0400 Subject: [PATCH 24/26] Revert "r" This reverts commit 5574ef78394564b298e753deef838d7aea0c47b1. --- tests/test_captum.py | 3 --- tests/test_dual.json | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_captum.py b/tests/test_captum.py index 78e316283..7846be8b0 100644 --- a/tests/test_captum.py +++ b/tests/test_captum.py @@ -4,7 +4,6 @@ from flood_forecast.interpretability import run_attribution, make_attribution_plots from flood_forecast.basic.gru_vanilla import VanillaGRU from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader -from captum.attr import IntegratedGradients class TestCaptum(unittest.TestCase): @@ -30,7 +29,5 @@ def test_run_attribution(self): def test_create_attribution_plots(self): """_summary_""" - attr = IntegratedGradients(self.test_model) - attr.attribute(self.test_data_loader) attributions, approx_error = run_attribution(self.test_model, self.test_data_loader, "IntegratedGradients", {}) make_attribution_plots(attributions, approx_error, use_wandb=False) diff --git a/tests/test_dual.json b/tests/test_dual.json index a52b07afd..850632927 100644 --- a/tests/test_dual.json +++ b/tests/test_dual.json @@ -20,7 +20,7 @@ "train_end": 100, "valid_start":101, "valid_end": 201, - "test_start": 232, + "test_start": 202, "test_end": 290, "target_col": ["cfs"], "relevant_cols": ["cfs", "precip", "temp"], From 7c2d1f01429739a8fd102e112921c5ac73a1965d Mon Sep 17 00:00:00 2001 From: isaacmg Date: Sun, 25 Jun 2023 04:21:51 -0400 Subject: [PATCH 25/26] add basic architecting --- flood_forecast/interpretability.py | 11 ++++++++++- tests/da_meta.json | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index 129f5a91e..37b7f4f67 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -1,5 +1,6 @@ from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation from typing import Tuple, Dict +import numpy as np attr_dict = {"IntegratedGradients": IntegratedGradients, "DeepLift": DeepLift, "GradientSHAP": GradientShap, "NoiseTunnel": NoiseTunnel, "FeatureAblation": FeatureAblation} @@ -35,7 +36,7 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl return attributions, approximation_error -def make_attribution_plots(attributions, approximation_error, use_wandb: bool = True): +def make_attribution_plots(attributions, approximation_error, model, x, y, use_wandb: bool = True): """Creates the attribution plots and logs them to wandb if use_wandb is True. :param attributions: A tensor of the attributions should be of dimension (batch_size, , n_features). @@ -45,4 +46,12 @@ def make_attribution_plots(attributions, approximation_error, use_wandb: bool = :param use_wandb: _description_, defaults to True :type use_wandb: bool, optional """ + x_axis_data = np.arange(x.shape[1]) + x_axis_data_labels = model.params["fea"] + + ig_attr_test_sum = attributions.detach().numpy().sum(0) + ig_attr_test_norm_sum = ig_attr_test_sum / np.linalg.norm(ig_attr_test_sum, ord=1) + + lin_weight = model.lin1.weight[0].detach().numpy() + y_axis_lin_weight = lin_weight / np.linalg.norm(lin_weight, ord=1) pass diff --git a/tests/da_meta.json b/tests/da_meta.json index d7b0cca62..5f7748685 100644 --- a/tests/da_meta.json +++ b/tests/da_meta.json @@ -35,7 +35,7 @@ "forecast_history":5, "forecast_length":1, "train_end": 300, - "valid_start":302, + "valid_start":333, "valid_end": 404, "test_end": 500, "target_col": ["cfs"], From a9053aeba678aff9ae0e27f94899b22e53f101e7 Mon Sep 17 00:00:00 2001 From: isaacmg Date: Sun, 25 Jun 2023 04:40:21 -0400 Subject: [PATCH 26/26] fixing code 3 --- flood_forecast/interpretability.py | 16 +++++----------- tests/da_meta.json | 2 +- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/flood_forecast/interpretability.py b/flood_forecast/interpretability.py index 37b7f4f67..57ae68226 100644 --- a/flood_forecast/interpretability.py +++ b/flood_forecast/interpretability.py @@ -1,6 +1,5 @@ from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation from typing import Tuple, Dict -import numpy as np attr_dict = {"IntegratedGradients": IntegratedGradients, "DeepLift": DeepLift, "GradientSHAP": GradientShap, "NoiseTunnel": NoiseTunnel, "FeatureAblation": FeatureAblation} @@ -36,22 +35,17 @@ def run_attribution(model, test_loader, method, additional_params: Dict) -> Tupl return attributions, approximation_error -def make_attribution_plots(attributions, approximation_error, model, x, y, use_wandb: bool = True): +def make_attribution_plots(model, methods, use_wandb: bool = True): """Creates the attribution plots and logs them to wandb if use_wandb is True. :param attributions: A tensor of the attributions should be of dimension (batch_size, , n_features). :type attributions: torch.Tensor :param approximation_error: _description_ :type approximation_error: _type_ - :param use_wandb: _description_, defaults to True + :param use_wandb: _description_, defaults2 to True :type use_wandb: bool, optional """ - x_axis_data = np.arange(x.shape[1]) - x_axis_data_labels = model.params["fea"] - - ig_attr_test_sum = attributions.detach().numpy().sum(0) - ig_attr_test_norm_sum = ig_attr_test_sum / np.linalg.norm(ig_attr_test_sum, ord=1) - - lin_weight = model.lin1.weight[0].detach().numpy() - y_axis_lin_weight = lin_weight / np.linalg.norm(lin_weight, ord=1) + for method in methods: + attributions, approx = run_attribution(model.model, model.test_loader, methods, {}) + # DO PLOTTING HERE pass diff --git a/tests/da_meta.json b/tests/da_meta.json index 5f7748685..d7b0cca62 100644 --- a/tests/da_meta.json +++ b/tests/da_meta.json @@ -35,7 +35,7 @@ "forecast_history":5, "forecast_length":1, "train_end": 300, - "valid_start":333, + "valid_start":302, "valid_end": 404, "test_end": 500, "target_col": ["cfs"],