-
Notifications
You must be signed in to change notification settings - Fork 15
Switching to Bridgestan and JAX #58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
aacbfc9
1c225c7
ed52ae2
e660c48
9ddf26e
8af155e
6c52ee8
5aced28
7e80d31
3c42431
e2ad4b5
dfcf197
497832d
54a4085
4426c96
4c3f78e
12bd16d
e78c030
cd8dd1c
cf68eb7
97c3f79
e740013
e198151
87946b4
25fe10c
dc70fda
c56166f
8c6ffd1
401032f
bcad736
ca81271
6b87646
0c888e4
30ad5eb
64c7d1f
3a1b47f
71f86f8
0f2a371
68d9f27
f69bea0
86ccc83
45d510f
6c26423
e61c5dd
a32fc8e
dfdfa28
4a395e7
5e031e5
0d1f58b
439ec10
d812f71
3d9a471
95cad86
224ef79
f623872
0b6aaae
ac8a6d2
58c65e5
8ff9e22
843dbcd
5633736
15509e1
b6ee706
acac0f2
72f0f88
1ba5a97
be36451
118ff29
aea91b2
0802e39
8385be6
117c2f5
d22311d
01ede86
aa0bd65
60d27e2
301b9b9
5a82388
a6c66a4
690eb57
8319eb8
0cc7f52
82b1b3d
1414a0f
3c3e5d6
55794ee
592e57c
5fabcb9
0d38e5b
2664e25
6580ba1
4656007
b49fdf1
e53542c
c3ecc7d
de1a86b
b904952
e2f29ef
09eb883
f80523f
d84dbf8
7740320
f0ec4df
39e814b
0d641df
d134b9b
6328dce
f15d6e3
3c3613d
b9e22b2
2160b44
f649419
92750dc
f8ae96d
01d22a1
2040ddb
390727c
4929f4a
203ce7c
9a22ba6
1b6e4ed
308a3a0
1991085
fe3cb31
bbd427b
9f10bd8
192a05d
c82fda8
e745899
017aeaa
3e14eec
baecec9
b3b6910
cccd66b
fd99f46
77ab9c9
2d3a7ff
b25054b
1b5c6d8
b0e97a2
08f05fb
dfd5047
a99f1aa
ea7637c
8972fe6
b595fad
eab26a0
26e3b2a
470af96
c27e6fa
8b1c70e
d4b2bc2
fbd8ddf
1e7157e
a56accd
c225580
957ead8
bdaba59
d83074d
28f63b6
e89c983
29ae60b
a4da3e1
0ea0e81
851e1c6
283a134
1efe233
b0d33b4
b4df4c9
8376b24
8de7241
3976e01
eddf500
4f91c7d
f3bb1ba
79d8279
37b5599
e0db797
2d2f9ca
0264df0
8248b6f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ | |
| codecov | ||
| coverage | ||
| pytest | ||
| pystan==2.19.1.1 | ||
| pystan>=3.1.0 | ||
|
|
||
| # lint | ||
| autoflake | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,3 +9,4 @@ sphinx_rtd_theme | |
| ipykernel | ||
| nbsphinx | ||
| nbstripout | ||
| bridgestan | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,11 @@ | ||
| numpy>=1.13 | ||
| scipy~=1.7.3 | ||
| autograd~=1.3 | ||
| scipy>=1.7.3 | ||
| jax>=0.4.1 | ||
| tqdm~=4.64.1 | ||
| paragami~=0.42 | ||
|
|
||
| pystan~=2.19.1.1 | ||
| bridgestan>=2.0.0 | ||
| pystan>=3.1.0 | ||
| pytest~=7.1.2 | ||
| viabel~=0.5.1 | ||
| setuptools~=65.5.0 | ||
| setuptools~=65.5.0 | ||
| jaxlib>=0.4.1 | ||
| nest_asyncio>=1.5.8 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| import warnings | ||
|
|
||
| import autograd.numpy as np | ||
| import jax.numpy as np | ||
| from scipy.fftpack import next_fast_len | ||
|
|
||
|
|
||
|
|
@@ -64,30 +64,30 @@ def ess(samples): | |
|
|
||
| rho_hat_t = np.zeros(n_draw) | ||
| rho_hat_even = 1.0 | ||
| rho_hat_t[0] = rho_hat_even | ||
| rho_hat_t =rho_hat_t.at[0].set(rho_hat_even) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing space after equals sign
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you even need the equality if you are using
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, because JAX arrays are immutable and |
||
| rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, 1])) / var_plus | ||
| rho_hat_t[1] = rho_hat_odd | ||
| rho_hat_t = rho_hat_t.at[1].set(rho_hat_odd) | ||
|
|
||
| # Geyer's initial positive sequence | ||
| t = 1 | ||
| while t < (n_draw - 3) and (rho_hat_even + rho_hat_odd) > 0.0: | ||
| rho_hat_even = 1.0 - (mean_var - np.mean(acov[:, t + 1])) / var_plus | ||
| rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, t + 2])) / var_plus | ||
| if (rho_hat_even + rho_hat_odd) >= 0: | ||
| rho_hat_t[t + 1] = rho_hat_even | ||
| rho_hat_t[t + 2] = rho_hat_odd | ||
| rho_hat_t = rho_hat_t.at[t+1].set(rho_hat_even) | ||
| rho_hat_t = rho_hat_t.at[t + 2].set(rho_hat_odd) | ||
| t += 2 | ||
|
|
||
| max_t = t - 2 | ||
| # improve estimation | ||
| if rho_hat_even > 0: | ||
| rho_hat_t[max_t + 1] = rho_hat_even | ||
| rho_hat_t = rho_hat_t.at[max_t + 1].set(rho_hat_even) | ||
| # Geyer's initial monotone sequence | ||
| t = 1 | ||
| while t <= max_t - 2: | ||
| if (rho_hat_t[t + 1] + rho_hat_t[t + 2]) > (rho_hat_t[t - 1] + rho_hat_t[t]): | ||
| rho_hat_t[t + 1] = (rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0 | ||
| rho_hat_t[t + 2] = rho_hat_t[t + 1] | ||
| rho_hat_t = rho_hat_t.at[t + 1].set((rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0) | ||
| rho_hat_t = rho_hat_t.at[t + 2].set(rho_hat_t[t + 1]) | ||
| t += 2 | ||
|
|
||
| ess = n_chain * n_draw | ||
|
|
@@ -117,6 +117,7 @@ def MCSE(sample): | |
| n_iters, d = sample.shape | ||
| sd_dev = np.sqrt(np.var(sample, ddof=1, axis=0)) | ||
| eff_samp = [ess(sample[:, i].reshape(1, n_iters)) for i in range(d)] | ||
| eff_samp = np.asarray(eff_samp) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. combine this line with the previous one |
||
| mcse = sd_dev / np.sqrt(eff_samp) | ||
| return eff_samp, mcse | ||
|
|
||
|
|
@@ -178,7 +179,8 @@ def R_hat_convergence_check(samples, windows, Rhat_threshold=1.1): | |
| best_W: `int` | ||
| Best window size | ||
| """ | ||
| R_hat_array = [np.max(compute_R_hat(np.array(samples[-window:]), 0)) for window in windows] | ||
| R_hat_array = [np.max(compute_R_hat(np.array(samples[-window:]), 0)) for window in windows] | ||
| R_hat_array = np.asarray(R_hat_array) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. combine this line with the previous one |
||
| best_R_hat_ind = np.argmin(R_hat_array) | ||
| success = R_hat_array[best_R_hat_ind] <= Rhat_threshold | ||
| return success, windows[best_R_hat_ind] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,16 +1,15 @@ | ||
| from abc import ABC, abstractmethod | ||
| import scipy | ||
|
|
||
| import autograd.numpy as np | ||
| import autograd.numpy.random as npr | ||
| import autograd.scipy.stats.norm as norm | ||
| import autograd.scipy.stats.t as t_dist | ||
| from autograd import elementwise_grad | ||
| from autograd.scipy.linalg import sqrtm | ||
| from paragami import ( | ||
| FlattenFunctionInput, NumericArrayPattern, NumericVectorPattern, PatternDict, | ||
| PSDSymmetricMatrixPattern) | ||
| import jax.numpy as np | ||
| import numpy.random as npr | ||
| import jax.scipy.stats.norm as norm | ||
| import jax.scipy.stats.t as t_dist | ||
| from jax import jvp | ||
| from viabel.function_patterns import FlattenFunctionInput | ||
| from viabel.patterns import NumericArrayPattern, NumericVectorPattern, PatternDict, PSDSymmetricMatrixPattern | ||
|
|
||
| from ._distributions import multivariate_t_logpdf | ||
| from viabel._distributions import multivariate_t_logpdf | ||
|
|
||
| __all__ = [ | ||
| 'ApproximationFamily', | ||
|
|
@@ -318,7 +317,10 @@ def _get_mu_sigma_pattern(dim): | |
| ms_pattern['Sigma'] = PSDSymmetricMatrixPattern(size=dim) | ||
| return ms_pattern | ||
|
|
||
|
|
||
| def sqrtm(matrix): | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use cholesky instead |
||
| L = scipy.linalg.cholesky(matrix) | ||
| return L | ||
|
|
||
| class MultivariateT(ApproximationFamily): | ||
| """A full-rank multivariate t approximation family.""" | ||
|
|
||
|
|
@@ -415,17 +417,19 @@ def __init__(self, layers_shapes, nonlinearity=np.tanh, last=np.tanh, | |
|
|
||
| def forward(self, var_param, x): | ||
| log_det_J = np.zeros(x.shape[0]) | ||
| derivative = elementwise_grad(self._nonlinearity) | ||
| derivative_last = elementwise_grad(self._last) | ||
| for layer_id in range(self._layers): | ||
| W = var_param[str(layer_id)] | ||
| b = var_param[str(layer_id) + "_b"] | ||
| if layer_id + 1 == self._layers: | ||
| x = self._last(np.dot(x, W) + b) | ||
| log_det_J += np.log(np.abs(np.dot(derivative_last(x), W.T).sum(axis=1))) | ||
| _, elementwise_gradient_last = jvp(self._last, (x,), (np.ones_like(x),)) | ||
|
|
||
| log_det_J += np.log(np.abs(np.dot(elementwise_gradient_last, W.T).sum(axis=1))) | ||
| else: | ||
| x = self._nonlinearity(np.dot(x, W) + b) | ||
| log_det_J += np.log(np.abs(np.dot(derivative(x), W.T).sum(axis=1))) | ||
| _, elementwise_gradient = jvp(self._nonlinearity, (x,), (np.ones_like(x),)) | ||
|
|
||
| log_det_J += np.log(np.abs(np.dot(elementwise_gradient, W.T).sum(axis=1))) | ||
| return x, log_det_J | ||
|
|
||
| def sample(self, var_param, n_samples): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add space after the equals sign