---
output: github_document
---
<!-- README.md is generated from README.Rmd. Please edit that file -->
```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "man/figures/README-",
fig.height = 3,
fig.width = 8
)
options(width = 200)
set.seed(12345)
```
# counterfactuals
<!-- badges: start -->
[![R-CMD-check](https://github.com/dandls/counterfactuals/workflows/R-CMD-check/badge.svg)](https://github.com/dandls/counterfactuals/actions)
<!-- badges: end -->
The `counterfactuals` package provides various (model-agnostic) counterfactual explanation methods via a unified R6-based interface.
Counterfactual explanation methods address questions of the form:
"For input $\mathbf{x^{\star}}$, the model predicted $y$. What needs to be changed in $\mathbf{x^{\star}}$ for the model
to predict a desired outcome $\tilde{y}$ instead?". \
Denied loan applications serve as a common example; here a counterfactual explanation (or counterfactual for short) could be:
"The loan was denied because the amount of €30k is too high
given the income. If the amount had been €20k, the loan would have been granted."
For an introduction to counterfactual explanation methods, we recommend Chapter 9.3 of the [Interpretable Machine Learning book](https://christophm.github.io/interpretable-ml-book/) by Christoph Molnar. The package is based on the R code underlying the paper [Multi-Objective Counterfactual Explanations
(MOC)](https://link.springer.com/chapter/10.1007/978-3-030-58112-1_31).
## Available methods
The following counterfactual explanation methods are currently implemented (they share a common interface, sketched after this list):
- [Multi-Objective Counterfactual Explanations (MOC)](https://link.springer.com/chapter/10.1007/978-3-030-58112-1_31)
- [Nearest Instance Counterfactual Explanations (NICE)](https://arxiv.org/abs/2104.07411) (an extended version)
- [WhatIf](https://arxiv.org/abs/1907.04135) (an extended version)
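All methods share a common two-step interface: construct the method object from an `iml::Predictor`, then call `find_counterfactuals()`. A minimal sketch, assuming `MOCClassif` mirrors the `WhatIfClassif` workflow shown in the example below (`predictor` and `x_interest` are defined there):

``` r
# hedged sketch: MOC with default settings, mirroring the WhatIf example below
moc_classif = MOCClassif$new(predictor)
cfactuals = moc_classif$find_counterfactuals(
  x_interest, desired_class = "versicolor", desired_prob = c(0.5, 1)
)
```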
## Installation
You can install the development version from [GitHub](https://github.com/) with:
``` r
# install.packages("devtools")
devtools::install_github("dandls/counterfactuals")
```
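Alternatively, the same installation works with the lighter `remotes` package:

``` r
# install.packages("remotes")
remotes::install_github("dandls/counterfactuals")
```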
## Get started
In this example, we train a `randomForest` on the `iris` dataset and examine how a given `virginica` observation
would have to change in order to be classified as `versicolor`.
```{r example, message=FALSE}
library(counterfactuals)  # counterfactual explanation methods
library(randomForest)     # random forest learner
library(iml)              # model-agnostic model wrapper (Predictor)
```
### Fitting a model
First, we train a `randomForest` model to predict the target variable `Species`. We omit one observation from the training
data; this held-out observation is `x_interest` (the observation $\mathbf{x^{\star}}$ for which we want to find counterfactuals).
```{r}
# train on all rows except row 150, which is held out as x_interest
rf = randomForest(Species ~ ., data = iris[-150L, ])
```
### Setting up an iml::Predictor() object
We then create an [`iml::Predictor`](https://christophm.github.io/iml/reference/Predictor.html) object, which serves as
a wrapper for different model types; it contains the model and the data for its analysis.
```{r}
# type = "prob" makes the predictor return class probabilities
predictor = Predictor$new(rf, type = "prob")
```
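If `iml` cannot infer the training data from the model object, it can be supplied explicitly via the `data` argument of `Predictor$new()` (an equivalent variant of the call above):

``` r
# equivalent set-up with the training data passed explicitly
predictor = Predictor$new(rf, data = iris[-150L, ], type = "prob")
```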
### Find counterfactuals
For `x_interest`, the model predicts a probability of 8% for class `versicolor`.
```{r}
# the observation held out from training above
x_interest = iris[150L, ]
predictor$predict(x_interest)
```
Now, we examine what needs to be changed in `x_interest` so that the model predicts a probability of at least 50% for class `versicolor`.
Here we apply WhatIf; since this is a classification task, we create a `WhatIfClassif` object.
```{r}
# WhatIf searches among observed data points; request the 5 closest ones
wi_classif = WhatIfClassif$new(predictor, n_counterfactuals = 5L)
```
Then, we use the `find_counterfactuals()` method to find counterfactuals for `x_interest`. The argument `desired_prob = c(0.5, 1)` specifies the targeted probability interval for the desired class (here: at least 50% for `versicolor`).
```{r}
cfactuals = wi_classif$find_counterfactuals(
x_interest, desired_class = "versicolor", desired_prob = c(0.5, 1)
)
```
### The counterfactuals object
`cfactuals` is a `Counterfactuals` object that contains the counterfactuals and has several methods for their
evaluation and visualization.
```{r}
cfactuals
```
The counterfactuals are stored in the `data` field.
```{r}
cfactuals$data
```
With the `evaluate()` method, we can evaluate the counterfactuals using various quality measures.
```{r}
cfactuals$evaluate()
```
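If only specific quality measures are of interest, `evaluate()` can presumably be restricted to them. A hedged sketch: the `measures` argument and the measure names below are assumptions based on the objectives of the MOC paper (proximity, sparsity) and may differ from the installed API:

``` r
# hedged sketch: evaluate only proximity and sparsity
# (argument and measure names are assumptions; see ?Counterfactuals)
cfactuals$evaluate(measures = c("dist_x_interest", "no_changed"))
```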
One visualization option is to plot the frequency of feature changes across all counterfactuals using the
`plot_freq_of_feature_changes()` method.
```{r, fig.height=2}
cfactuals$plot_freq_of_feature_changes()
```
Another visualization option is a parallel plot---created with the `plot_parallel()` method---that connects the (scaled)
feature values of each counterfactual and highlights `x_interest` in blue.
```{r, fig.height=2.5, message=FALSE}
cfactuals$plot_parallel()
```
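As a final sanity check, the `iml::Predictor` from above can be applied to the counterfactuals themselves; each row should reach a `versicolor` probability within the desired interval (this assumes the columns of `cfactuals$data` match the features the model was trained on):

``` r
# sanity check: model predictions for the counterfactuals found above
predictor$predict(cfactuals$data)
```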