Commit 263832a ("Updated")

1 parent 299bfc0, commit 263832a

7 files changed: +424 -100 lines

algorithms.py (+47 -6)
@@ -1,7 +1,7 @@
 import numpy as np
 from numba import njit
 import sobol_seq
-import kernel_herding
+import kernel_methods
 from itertools import count
 from math import cos, gamma, pi, sin, sqrt
 from typing import Callable, Iterator, List
@@ -84,14 +84,16 @@ def owen_complement(X_background, X_foreground, predict_function, n_samples):
 
 
 @njit
-def _accumulate_samples_castro(phi, predictions, j):
+def _accumulate_samples_castro(phi, predictions, j, weights=None):
+    if weights == None:
+        weights = np.full(predictions.shape[1], 1 / predictions.shape[1])
     for foreground_idx in range(predictions.shape[0]):
         for sample_idx in range(predictions.shape[1]):
             phi[foreground_idx, j[sample_idx]] += predictions[foreground_idx][
-                sample_idx] / predictions.shape[1]
+                sample_idx] * weights[sample_idx]
 
 
-def estimate_shap_given_permutations(X_background, X_foreground, predict_function, p):
+def estimate_shap_given_permutations(X_background, X_foreground, predict_function, p, weights=None):
     n_features = X_background.shape[1]
     phi = np.zeros((X_foreground.shape[0], n_features))
     n_permutations = p.shape[0]
@@ -105,7 +107,7 @@ def estimate_shap_given_permutations(X_background, X_foreground, predict_function
         predictions = (pred_on - pred_off).reshape(
             (X_foreground.shape[0], mask.shape[0], X_background.shape[0]))
         predictions = np.mean(predictions, axis=2)
-        _accumulate_samples_castro(phi, predictions, j)
+        _accumulate_samples_castro(phi, predictions, j, weights)
        pred_off = pred_on
 
     return phi
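
The two hunks above generalize the uniform Monte Carlo average to a weighted sum over permutation samples: when weights is None, the njit helper falls back to uniform weights of 1 / n_samples, so existing callers of estimate_shap_given_permutations are unchanged. A minimal sanity sketch of that equivalence (shapes and names below are illustrative, not from the repository):

import numpy as np

# With uniform weights, the weighted accumulation reproduces the old
# "divide by the number of samples" behaviour.
rng = np.random.default_rng(0)
predictions = rng.normal(size=(3, 5))   # (n_foreground, n_samples)
j = rng.integers(0, 4, size=5)          # feature index credited by each sample
uniform_w = np.full(5, 1 / 5)
phi_old = np.zeros((3, 4))
phi_new = np.zeros((3, 4))

for f in range(3):
    for s in range(5):
        phi_old[f, j[s]] += predictions[f, s] / 5
        phi_new[f, j[s]] += predictions[f, s] * uniform_w[s]

assert np.allclose(phi_old, phi_new)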
@@ -123,6 +125,28 @@ def monte_carlo(X_background, X_foreground, predict_function, n_samples):
     return estimate_shap_given_permutations(X_background, X_foreground, predict_function, p)
 
 
+def monte_carlo_weighted(X_background, X_foreground, predict_function, n_samples):
+    n_features = X_background.shape[1]
+    assert n_samples % (n_features + 1) == 0
+    # castro is allowed to take 2 * more samples than owen as it reuses predictions
+    samples_per_feature = 2 * (n_samples // (n_features + 1))
+    p = np.zeros((samples_per_feature, n_features), dtype=np.int64)
+    for i in range(samples_per_feature):
+        p[i] = np.random.permutation(n_features)
+    weights = kernel_methods.compute_bayesian_weights(p, kernel_methods.kt_kernel)
+    return estimate_shap_given_permutations(X_background, X_foreground, predict_function, p,
+                                            weights)
+
+
+def sbq(X_background, X_foreground, predict_function, n_samples):
+    n_features = X_background.shape[1]
+    assert n_samples % (n_features + 1) == 0
+    samples_per_feature = 2 * (n_samples // (n_features + 1))
+    p, w = kernel_methods.sequential_bayesian_quadrature(samples_per_feature, n_features)
+    return estimate_shap_given_permutations(X_background, X_foreground, predict_function, p,
+                                            w)
+
+
 def monte_carlo_antithetic(X_background, X_foreground, predict_function, n_samples):
     n_features = X_background.shape[1]
 
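
compute_bayesian_weights and sequential_bayesian_quadrature come from kernel_methods.py, which is among the changed files but not expanded in this view. As a rough sketch of the underlying idea, not the repository's implementation: Bayesian (kernel) quadrature replaces the uniform 1/n weights with the solution of K w = z, where K is the Gram matrix of the sampled permutations under a permutation kernel and z is the kernel mean embedding of the uniform distribution over permutations; for a right-invariant kernel z is constant, so the weights reduce to a normalized K^-1 1. A hedged sketch, assuming a kernel callable of the same kind as kt_kernel:

import numpy as np

def bayesian_weights_sketch(p, kernel):
    # Illustrative only -- not kernel_methods.compute_bayesian_weights.
    # p: (n_permutations, n_features) array, kernel: callable on two permutations.
    n = p.shape[0]
    K = np.empty((n, n))
    for a in range(n):
        for b in range(n):
            K[a, b] = kernel(p[a], p[b])
    K += 1e-8 * np.eye(n)                 # jitter for numerical stability
    w = np.linalg.solve(K, np.ones(n))    # constant embedding -> solve K w = 1
    return w / w.sum()                    # normalize so the weights sum to 1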
@@ -157,6 +181,7 @@ def sobol_sphere_permutations(n_samples, n_features):
 
     return np.argsort(sobol, axis=1)
 
+
 # sample with l ones and i off
 def draw_castro_stratified_samples(n_samples, n_features, i, l):
     mask = np.zeros((n_samples, n_features - 1), dtype=bool)
@@ -272,7 +297,7 @@ def kt_herding(X_background, X_foreground, predict_function, n_samples):
     assert n_samples % (n_features + 1) == 0
     # castro is allowed to take 2 * more samples than owen as it reuses predictions
     samples_per_feature = 2 * (n_samples // (n_features + 1))
-    p = kernel_herding.kt_herding_permutations(samples_per_feature, n_features)
+    p = kernel_methods.kt_herding_permutations(samples_per_feature, n_features)
     return estimate_shap_given_permutations(X_background, X_foreground, predict_function, p)
 
 
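
The only change in this hunk is the module rename from the deleted kernel_herding.py to kernel_methods; kt_herding_permutations itself is not shown in this commit view. For orientation, kernel herding typically picks each new sample greedily to best match the kernel mean embedding. A hedged sketch under the assumption of a greedy search over a random candidate pool (again, not the repository's code):

import numpy as np

def herding_permutations_sketch(n_select, n_features, kernel, n_candidates=256, seed=0):
    # Illustrative greedy kernel herding over permutations
    # (not kernel_methods.kt_herding_permutations).
    rng = np.random.default_rng(seed)
    pool = np.array([rng.permutation(n_features) for _ in range(n_candidates)])
    # Approximate the kernel mean embedding mu(x) = E[k(x, P)] with the candidate pool.
    mu = np.array([np.mean([kernel(x, y) for y in pool]) for x in pool])
    running = np.zeros(n_candidates)      # sum of k(x, x_i) over already-chosen x_i
    chosen = []
    for t in range(n_select):
        idx = int(np.argmax(mu - running / (t + 1)))   # standard herding update
        chosen.append(idx)
        running += np.array([kernel(x, pool[idx]) for x in pool])
    return pool[chosen]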
@@ -330,6 +355,16 @@ def orthogonal(X_background, X_foreground, predict_function, n_samples):
     return estimate_shap_given_permutations(X_background, X_foreground, predict_function, p)
 
 
+def orthogonal_weighted(X_background, X_foreground, predict_function, n_samples):
+    n_features = X_background.shape[1]
+    assert n_samples % (2 * (n_features + 1)) == 0
+    # castro is allowed to take 2 * more samples than owen as it reuses predictions
+    samples_per_feature = 2 * (n_samples // (n_features + 1))
+    p = _orthogonal_permutations(samples_per_feature, n_features)
+    w = kernel_methods.compute_bayesian_weights(p, kernel_methods.kt_kernel)
+    return estimate_shap_given_permutations(X_background, X_foreground, predict_function, p, w)
+
+
 def _int_sin_m(x: float, m: int) -> float:
     """Computes the integral of sin^m(t) dt from 0 to x recursively"""
     if m == 0:
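
Both weighted variants hand kernel_methods.kt_kernel to compute_bayesian_weights; its definition is not part of this view. If it follows the usual Kendall tau kernel for permutations, it counts concordant minus discordant index pairs and normalizes by the total number of pairs. A hedged sketch of that assumed form:

from itertools import combinations

import numpy as np

def kendall_kernel_sketch(p1, p2):
    # Assumed form of a Kendall tau kernel between two permutations; returns a
    # value in [-1, 1]: (concordant pairs - discordant pairs) / C(d, 2).
    d = len(p1)
    total = 0
    for i, j in combinations(range(d), 2):
        total += np.sign(p1[i] - p1[j]) * np.sign(p2[i] - p2[j])
    return total / (d * (d - 1) / 2)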
@@ -438,6 +473,10 @@ def fibonacci(X_background, X_foreground, predict_function, n_samples):
 def min_sample_size(alg, n_features):
     if alg == monte_carlo:
         return n_features + 1
+    elif alg == monte_carlo_weighted:
+        return n_features + 1
+    elif alg == sbq:
+        return n_features + 1
     elif alg == qmc_sobol:
         return n_features + 1
     elif alg == fibonacci:
@@ -450,6 +489,8 @@ def min_sample_size(alg, n_features):
         return 2 * (n_features + 1)
     elif alg == orthogonal:
         return 2 * (n_features + 1)
+    elif alg == orthogonal_weighted:
+        return 2 * (n_features + 1)
     elif alg == owen or alg == owen_complement:
         return n_features * 4
     elif alg == castro_stratified:
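
The min_sample_size additions mirror the asserts in the new estimators: monte_carlo_weighted and sbq need n_samples to be a multiple of n_features + 1, while orthogonal_weighted needs a multiple of 2 * (n_features + 1). A hedged usage sketch, assuming the repository's algorithms and kernel_methods modules are importable and using placeholder data and model:

import numpy as np

import algorithms

n_features = 8
X_background = np.random.randn(100, n_features)   # placeholder background data
X_foreground = np.random.randn(10, n_features)    # placeholder rows to explain

def predict_function(X):
    return X.sum(axis=1)                           # placeholder model

# Pick a budget that is a multiple of min_sample_size for the chosen estimator.
budget = 20 * algorithms.min_sample_size(algorithms.sbq, n_features)
phi = algorithms.sbq(X_background, X_foreground, predict_function, budget)
print(phi.shape)   # (10, 8): one Shapley estimate per foreground row and feature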

experiments.py (+8 -10)
@@ -21,23 +21,21 @@ def plot_experiments():
     repeats = 25
     foreground_examples = 10
     background_examples = 100
-    max_evals = 100000
+    max_evals = 5000
     datasets_set = {
         "make_regression": datasets.get_regression(foreground_examples, background_examples),
         "cal_housing": datasets.get_cal_housing(foreground_examples, background_examples),
         "adult": datasets.get_adult(foreground_examples, background_examples),
         "breast_cancer": datasets.get_breast_cancer(foreground_examples, background_examples),
     }
     algorithms_set = {
-        "Castro": algorithms.monte_carlo,
-        "Castro-Complement": algorithms.monte_carlo_antithetic,
-        # "Castro-LHS": algorithms.castro_lhs,
-    }
-    algorithms_set = {
-        # "Castro": algorithms.castro,
-        "Castro-Orthogonal": algorithms.orthogonal,
-        "Castro-Complement": algorithms.monte_carlo_antithetic,
-        "Fibonacci": algorithms.fibonacci,
+        # "MC": algorithms.monte_carlo,
+        # "Bayesian-MC": algorithms.monte_carlo_weighted,
+        "SBQ": algorithms.sbq,
+        # "MC-Orthogonal-Bayesian": algorithms.orthogonal_weighted,
+        "MC-Orthogonal": algorithms.orthogonal,
+        # "Castro-Complement": algorithms.monte_carlo_antithetic,
+        # "Fibonacci": algorithms.fibonacci,
         # "Castro-ControlVariate": algorithms.castro_control_variate,
         # "Castro-QMC": algorithms.castro_qmc,
         # "KT-Herding": algorithms.kt_herding,
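
With max_evals reduced to 5000 and the active set narrowed to SBQ and MC-Orthogonal, the benchmark now compares the Bayesian quadrature variant against plain orthogonal sampling. The body of plot_experiments is not shown in this view; the sketch below only illustrates how a name -> algorithm dict like algorithms_set is typically swept, with exact_phi (ground-truth Shapley values) and the squared-error metric as assumptions:

import numpy as np

import algorithms

def sweep_sketch(algorithms_set, X_background, X_foreground, predict_function, exact_phi,
                 max_evals=5000, repeats=25):
    # Illustrative benchmark loop, not the repository's plot_experiments.
    results = {}
    for name, alg in algorithms_set.items():
        step = algorithms.min_sample_size(alg, X_background.shape[1])
        budgets = list(range(step, max_evals + 1, step))
        errors = []
        for n in budgets:
            runs = [alg(X_background, X_foreground, predict_function, n) for _ in range(repeats)]
            errors.append(np.mean([np.mean((phi - exact_phi) ** 2) for phi in runs]))
        results[name] = (budgets, errors)
    return results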

kernel_herding.py (-76)

This file was deleted.
