This repository is the official implementation of "Towards More Robust Interpretation via Local Gradient Alignment".
To use this repository, you should be familiar with PyTorch and PyTorch Lightning. We also recommend using one of the logging frameworks supported by PyTorch Lightning, such as Weights & Biases or Neptune.
- Supported training losses: CE loss, Hessian regularizer, and l2+cosd regularizer
- Supported datasets: CIFAR10 and ImageNet100
- Supported models: LeNet and ResNet18
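The l2+cosd regularizer is named after two distances. Assuming they are measured between input gradients at nearby points (as the paper's title, local gradient alignment, suggests), the two distances themselves can be sketched on plain Python lists. This is only an illustration: the function names are hypothetical, and how the repository perturbs the input, computes the gradients, and weights the two terms is defined in pl_l2_plus_cosd_classifier.py and the paper, not here.

```python
import math
from typing import Sequence


def l2_distance(g1: Sequence[float], g2: Sequence[float]) -> float:
    """Euclidean (l2) distance between two flattened gradient vectors."""
    if len(g1) != len(g2):
        raise ValueError("gradient vectors must have the same length")
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(g1, g2)))


def cosine_distance(g1: Sequence[float], g2: Sequence[float], eps: float = 1e-12) -> float:
    """Cosine distance, i.e. 1 - cosine similarity.

    0 for parallel vectors, 1 for orthogonal ones, 2 for opposite ones.
    eps guards against division by zero for all-zero gradients.
    """
    if len(g1) != len(g2):
        raise ValueError("gradient vectors must have the same length")
    dot = sum(a * b for a, b in zip(g1, g2))
    norm1 = math.sqrt(sum(a * a for a in g1))
    norm2 = math.sqrt(sum(b * b for b in g2))
    return 1.0 - dot / max(norm1 * norm2, eps)
```

Note that the cosine distance ignores the scale of the gradients, whereas the l2 distance does not.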
We highly recommend using our conda environment:
```bash
# clone project
git clone https://github.com/joshua840/RobustAttributionGradAlignment.git RobustAGA
# install project
cd RobustAGA
conda env create -f environment.yml
conda activate agpu_env
```

Our directory structure looks like this:
```
├── project
│   ├── module                            <- All modules are in this directory
│   │   ├── lrp_module                    <- Modules for LRP (layer-wise relevance propagation) explanations
│   │   ├── models                        <- Models
│   │   ├── utils                         <- Utilities
│   │   ├── pl_classifier.py              <- Basic classifier
│   │   ├── pl_hessian_classifier.py      <- Hessian regularization
│   │   ├── pl_l2_plus_cosd_classifier.py <- l2 + cosd regularization
│   │   ├── test_adv_insertion.py         <- Run Adv-Insertion test
│   │   ├── test_insertion.py             <- Run Insertion test
│   │   ├── test_rps.py                   <- Run Random Perturbation Similarity test
│   │   ├── test_taps_saps.py             <- Run adversarial attack test
│   │   └── test_upper_bound.py           <- Run upper bound test
│   │
│   ├── main.py                           <- Run training & testing
│   └── test_main.py                      <- Run advanced test code
│
├── scripts                               <- Shell scripts
│
├── .gitignore                            <- List of files/folders ignored by git
├── environment.yml                       <- Anaconda environment
└── README.md
```
You can list the available arguments by passing -h on the command line:
```bash
python project/main.py -h
```

```
usage: main.py [-h] [--seed SEED] [--regularizer REGULARIZER] [--loggername LOGGERNAME] [--project PROJECT] [--dataset DATASET] [--model MODEL]
               [--activation_fn ACTIVATION_FN] [--softplus_beta SOFTPLUS_BETA] [--optimizer OPTIMIZER] [--weight_decay WEIGHT_DECAY]
               [--learning_rate LEARNING_RATE] [--milestones MILESTONES [MILESTONES ...]] [--num_workers NUM_WORKERS] [--batch_size_train BATCH_SIZE_TRAIN]
               [--batch_size_test BATCH_SIZE_TEST] [--data_dir DATA_DIR]

optional arguments:
  -h, --help            show this help message and exit
  --seed SEED           random seeds (default: 1234)
  --regularizer REGULARIZER
                        A regularizer to be used (default: none)
  --loggername LOGGERNAME
                        a name of logger to be used (default: default)
  --project PROJECT     a name of project to be used (default: default)
  --dataset DATASET     dataset to be loaded (default: cifar10)

Default classifier:
  --model MODEL         which model to be used (default: none)
  --activation_fn ACTIVATION_FN
                        activation function of model (default: relu)
  --softplus_beta SOFTPLUS_BETA
                        beta of softplus (default: 20.0)
  --optimizer OPTIMIZER
                        optimizer name (default: adam)
  --weight_decay WEIGHT_DECAY
                        weight decay for optimizer (default: 4e-05)
  --learning_rate LEARNING_RATE
                        learning rate for optimizer (default: 0.001)
  --milestones MILESTONES [MILESTONES ...]
                        lr scheduler (default: [100, 150])

Data arguments:
  --num_workers NUM_WORKERS
                        number of workers (default: 4)
  --batch_size_train BATCH_SIZE_TRAIN
                        batchsize of data loaders (default: 128)
  --batch_size_test BATCH_SIZE_TEST
                        batchsize of data loaders (default: 100)
  --data_dir DATA_DIR   directory of cifar10 dataset (default: /data/cifar10)
```
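Two of the defaults above can be made concrete with a small sketch: what --softplus_beta controls when the softplus activation is used, and how --milestones would change the learning rate if they feed a step scheduler. Both parts are assumptions rather than the repository's code: softplus mirrors the formula of torch.nn.Softplus (including its default threshold of 20), and lr_at_epoch assumes a MultiStepLR-style schedule with PyTorch's default decay factor gamma=0.1. Check pl_classifier.py for what is actually used.

```python
import math
from bisect import bisect_right
from typing import Sequence


def softplus(x: float, beta: float = 20.0, threshold: float = 20.0) -> float:
    """Smooth approximation of ReLU: (1 / beta) * log(1 + exp(beta * x)).

    Like torch.nn.Softplus, return x unchanged when beta * x > threshold to
    avoid overflow in exp(). A larger beta makes the curve closer to ReLU.
    """
    z = beta * x
    if z > threshold:
        return x
    return math.log1p(math.exp(z)) / beta


def lr_at_epoch(epoch: int, base_lr: float = 1e-3,
                milestones: Sequence[int] = (100, 150), gamma: float = 0.1) -> float:
    """Learning rate after `epoch` scheduler steps in a MultiStepLR-style schedule.

    The base rate is multiplied by gamma once for every milestone already reached.
    """
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)
```

Under these assumptions and the defaults above, the learning rate stays at 1e-3 until epoch 99, drops to 1e-4 at epoch 100, and to 1e-5 at epoch 150.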
In our code, the training module is selected by the --regularizer option (see project/main.py).
Each LightningModule class defines an add_model_specific_args function that registers the additional arguments used by that class.
Passing --regularizer together with -h also prints these additional arguments:
```bash
python project/main.py --regularizer l2_cosd -h
```

```
l2_cosd arguments:
  --eps EPS
  --lamb LAMB
  --lamb_c LAMB_C
  --detach_source_grad DETACH_SOURCE_GRAD
```

```bash
python project/main.py --regularizer hessian -h
```

```
Hessian arguments:
  --lamb LAMB
```
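The add_model_specific_args pattern described above can be sketched with argparse alone. This is a minimal sketch under stated assumptions: the two classes are hypothetical stand-ins for the real LightningModules, the argument types and defaults are placeholders (only the argument names and group titles come from the help output above), and whether main.py uses exactly this two-stage parse is an assumption.

```python
import argparse
from typing import List, Optional, Tuple


class HessianModuleSketch:
    """Hypothetical stand-in for the class in pl_hessian_classifier.py."""

    @staticmethod
    def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parent_parser.add_argument_group("Hessian arguments")
        group.add_argument("--lamb", type=float, default=None)  # placeholder type/default
        return parent_parser


class L2CosdModuleSketch:
    """Hypothetical stand-in for the class in pl_l2_plus_cosd_classifier.py."""

    @staticmethod
    def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parent_parser.add_argument_group("l2_cosd arguments")
        # Names come from the help output; types and defaults are placeholders.
        group.add_argument("--eps", type=float, default=None)
        group.add_argument("--lamb", type=float, default=None)
        group.add_argument("--lamb_c", type=float, default=None)
        group.add_argument("--detach_source_grad", type=str, default=None)
        return parent_parser


# Hypothetical mapping from --regularizer values to module classes.
REGULARIZERS = {"hessian": HessianModuleSketch, "l2_cosd": L2CosdModuleSketch}


def build_parser(argv: List[str]) -> Tuple[argparse.ArgumentParser, Optional[type]]:
    # Stage 1: peek at --regularizer only. add_help=False leaves -h for the
    # full parser, which then also lists the regularizer-specific group.
    pre = argparse.ArgumentParser(add_help=False)
    pre.add_argument("--regularizer", default="none")
    known, _ = pre.parse_known_args(argv)

    # Stage 2: shared arguments plus the arguments of the selected class.
    parser = argparse.ArgumentParser(parents=[pre])
    parser.add_argument("--seed", type=int, default=1234)
    module_cls = REGULARIZERS.get(known.regularizer)
    if module_cls is not None:
        parser = module_cls.add_model_specific_args(parser)
    return parser, module_cls


def parse_cli(argv: List[str]) -> Tuple[argparse.Namespace, Optional[type]]:
    parser, module_cls = build_parser(argv)
    return parser.parse_args(argv), module_cls
```

For example, parse_cli(["--regularizer", "hessian", "--lamb", "0.5"]) returns the Hessian stand-in class and a namespace with lamb=0.5, while adding -h prints the "Hessian arguments" group, which matches the behavior described above.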
4.3 Hidden arguments
The PyTorch Lightning Trainer exposes many useful training arguments on the command line. For example, we used --max_epochs and --default_root_dir in our experiments. We strongly recommend referring to the PyTorch Lightning Trainer documentation for the full argument list.
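In PyTorch Lightning 1.x, Trainer flags such as --max_epochs were typically registered with Trainer.add_argparse_args(parser) and consumed with Trainer.from_argparse_args(args); both were removed in Lightning 2.0, and whether main.py uses them is an assumption. The stdlib-only sketch below mimics the idea for the two flags mentioned above: it separates them from the repository's own arguments and collects them into a keyword dictionary that could be passed as Trainer(**trainer_kwargs). The helper name and the two-flag subset are hypothetical.

```python
import argparse
from typing import Any, Dict, List, Tuple


def split_trainer_args(argv: List[str]) -> Tuple[Dict[str, Any], List[str]]:
    """Pull a (hypothetical) subset of Trainer flags out of argv.

    Returns the Trainer keyword arguments that were actually set and the
    remaining, unparsed command-line tokens for the repository's own parser.
    """
    parser = argparse.ArgumentParser(add_help=False)
    group = parser.add_argument_group("Trainer arguments (subset)")
    group.add_argument("--max_epochs", type=int, default=None)
    group.add_argument("--default_root_dir", type=str, default=None)
    known, rest = parser.parse_known_args(argv)
    # Drop unset flags so that the Trainer keeps its own defaults.
    trainer_kwargs = {key: value for key, value in vars(known).items() if value is not None}
    return trainer_kwargs, rest
```

For instance, split_trainer_args(["--max_epochs", "200", "--seed", "1"]) yields {"max_epochs": 200} for the Trainer and leaves ["--seed", "1"] for the remaining arguments.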
We support three loggers:
- TensorBoard (https://www.tensorflow.org/tensorboard)
  - Logs and model checkpoints are saved in --default_root_dir.
  - Logging the test code with TensorBoard is not available.
- Weights & Biases (https://wandb.ai/site)
  - Create a new project on the WandB website.
  - Specify the project with the --project argument.
- Neptune AI (https://neptune.ai/)
  - Create a new project on the Neptune website.
  - export NEPTUNE_API_TOKEN="YOUR API TOKEN"
  - export NEPTUNE_ID="YOUR ID"
  - Set --default_root_dir to output/YOUR_PROJECT_NAME
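Before launching a Neptune run, it can help to check that the two environment variables above are set. The helper below is a hypothetical sketch, not part of the repository. It also assumes that NEPTUNE_ID is your Neptune workspace name, since Neptune addresses projects as workspace/project.

```python
import os
from typing import List, Mapping, Optional

# Environment variables listed in the Neptune setup steps above.
REQUIRED_NEPTUNE_VARS = ("NEPTUNE_API_TOKEN", "NEPTUNE_ID")


def missing_neptune_vars(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the required variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_NEPTUNE_VARS if not env.get(name)]


def neptune_project(project: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Build a 'workspace/project' name, assuming NEPTUNE_ID is the workspace."""
    env = os.environ if env is None else env
    missing = missing_neptune_vars(env)
    if missing:
        raise RuntimeError("missing environment variables: " + ", ".join(missing))
    return f"{env['NEPTUNE_ID']}/{project}"
```

Calling neptune_project("YOUR_PROJECT_NAME") before training fails fast with a clear message instead of failing later inside the logger.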
Likewise, you can check the options for each test method:

```bash
python project/test_main.py --test_method aopc -h
python project/test_main.py --test_method adv -h
python project/test_main.py --test_method adv_aopc -h
python project/test_main.py --test_method rps -h
python project/test_main.py --test_method upper_bound -h
```

For the test methods above, you must specify the --exp_id argument. You can find the experiment ID on your project's web page; it looks like EXP1-1 for Neptune and 1skdq34 for WandB. These runs append additional logs to your existing projects.
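To run several test methods against the same trained run, the commands above can be generated in a loop. This is a hypothetical convenience script: it only assumes the --test_method values listed above and the --exp_id flag, and any other arguments a test method requires must be supplied through extra_args.

```python
import shlex
import subprocess
from typing import List, Sequence

# Test methods taken from the commands above.
TEST_METHODS = ("aopc", "adv", "adv_aopc", "rps", "upper_bound")


def build_test_commands(exp_id: str, methods: Sequence[str] = TEST_METHODS,
                        extra_args: Sequence[str] = ()) -> List[List[str]]:
    """One argv list per test method, all pointing at the same experiment ID."""
    return [
        ["python", "project/test_main.py", "--test_method", method, "--exp_id", exp_id, *extra_args]
        for method in methods
    ]


def run_all(exp_id: str, dry_run: bool = True, **kwargs) -> None:
    """Print each command; execute it only when dry_run is False."""
    for cmd in build_test_commands(exp_id, **kwargs):
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing test
```

Running run_all("EXP1-1") first prints the commands without executing them; pass dry_run=False once they look right.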
This project is set up as a package, which means you can import any file into any other file like so:
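```python
from pytorch_lightning import Trainer

from project.module.pl_classifier import LitClassifier
from project.module.utils.data_module import CIFAR10DataModule, ImageNet100DataModule

# Data: choose one of the two data modules
data_module = CIFAR10DataModule()
# data_module = ImageNet100DataModule()

# Model: replace the placeholder with a supported model (see `python project/main.py -h`)
model_name = 'YOUR_MODEL_NAME'
activation_fn = 'relu'  # default of --activation_fn
beta = 20.0             # default of --softplus_beta
model = LitClassifier(model=model_name, activation_fn=activation_fn, softplus_beta=beta).cuda()

# Train
trainer = Trainer()
trainer.fit(model, datamodule=data_module)

# Test using the best checkpoint saved during training
trainer.test(model, datamodule=data_module, ckpt_path='best')
```

Similarly, you can import the test modules and the interpreter to compute heatmaps for a trained checkpoint: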
```python
import torch

from project.module.test_upper_bound import LitClassifierUpperBoundTester as LitClassifier
from project.module.utils.data_module import CIFAR10DataModule, ImageNet100DataModule
from project.module.utils.interpreter import Interpreter

# Data: load one test batch
data_module = CIFAR10DataModule(dataset='cifar10', batch_size_test=10, data_dir='../data/cifar10')
data_module.prepare_data()
data_module.setup()
test_loader = data_module.test_dataloader()
x_batch, y_batch = next(iter(test_loader))
x_s = x_batch.cuda().requires_grad_()  # input gradients are needed for the heatmap
y_s = y_batch.cuda()

# Model: restore hyperparameters and weights from a Lightning checkpoint
ckpt_path = 'YOUR_CHECKPOINT_PATH'
ckpt = torch.load(ckpt_path)
args = ckpt['hyper_parameters']
model = LitClassifier(**args).cuda()
model.load_state_dict(ckpt['state_dict'])
model.eval()

# Use interpreter: compute a "grad" heatmap for the batch
yhat_s = model(x_s)
h_s = Interpreter(model).get_heatmap(x_s, y_s, yhat_s, "grad", 'standard', 'abs', False).detach()
```

```bibtex
@article{joo2022towards,
  title={Towards More Robust Interpretation via Local Gradient Alignment},
  author={Joo, Sunghwan and Jeong, Seokhyeon and Heo, Juyeon and Weller, Adrian and Moon, Taesup},
  journal={arXiv preprint arXiv:2211.15900},
  year={2022}
}
```

The citation for the AAAI version is to be updated.