Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
181 commits
Select commit Hold shift + click to select a range
aacbfc9
Add files via upload
CyrusZhang73 Sep 25, 2023
1c225c7
Update _distributions.py
CyrusZhang73 Sep 25, 2023
ed52ae2
Delete viabel/_mc_diagnostics.py
CyrusZhang73 Sep 25, 2023
e660c48
Add files via upload
CyrusZhang73 Sep 25, 2023
9ddf26e
Update _utils.py
CyrusZhang73 Sep 25, 2023
8af155e
Update approximations.py
CyrusZhang73 Sep 25, 2023
6c52ee8
Update models.py
CyrusZhang73 Sep 25, 2023
5aced28
Update objectives.py
CyrusZhang73 Sep 25, 2023
7e80d31
Update objectives.py
CyrusZhang73 Sep 25, 2023
3c42431
Update optimization.py
CyrusZhang73 Sep 25, 2023
e2ad4b5
Update test_convenience.py
CyrusZhang73 Sep 25, 2023
dfcf197
Update test_models.py
CyrusZhang73 Sep 25, 2023
497832d
Add files via upload
CyrusZhang73 Sep 25, 2023
54a4085
Update test_objectives.py
CyrusZhang73 Sep 25, 2023
4426c96
Update test_optimization.py
CyrusZhang73 Sep 25, 2023
4c3f78e
Update _utils.py
CyrusZhang73 Sep 25, 2023
12bd16d
Delete viabel/jax_paragami directory
CyrusZhang73 Sep 25, 2023
e78c030
Add files via upload
CyrusZhang73 Sep 25, 2023
cd8dd1c
Update approximations.py
CyrusZhang73 Sep 25, 2023
cf68eb7
Update approximations.py
CyrusZhang73 Sep 25, 2023
97c3f79
Update _utils.py
CyrusZhang73 Sep 25, 2023
e740013
Update objectives.py
CyrusZhang73 Sep 25, 2023
e198151
Update objectives.py
CyrusZhang73 Sep 25, 2023
87946b4
Update objectives.py
CyrusZhang73 Sep 25, 2023
25fe10c
Update test_objectives.py
CyrusZhang73 Sep 25, 2023
dc70fda
Update objectives.py
CyrusZhang73 Sep 25, 2023
c56166f
Update objectives.py
CyrusZhang73 Sep 25, 2023
8c6ffd1
Update requirements.txt
CyrusZhang73 Sep 27, 2023
401032f
Update requirements-docs.txt
CyrusZhang73 Sep 27, 2023
bcad736
Update requirements.txt
CyrusZhang73 Sep 27, 2023
ca81271
Update requirements-dev.txt
CyrusZhang73 Sep 27, 2023
6b87646
Update requirements.txt
CyrusZhang73 Sep 27, 2023
0c888e4
Update objectives.py
CyrusZhang73 Sep 27, 2023
30ad5eb
Update requirements.txt
CyrusZhang73 Sep 27, 2023
64c7d1f
Update optimization.py
CyrusZhang73 Sep 27, 2023
3a1b47f
Update objectives.py
CyrusZhang73 Sep 27, 2023
71f86f8
Update objectives.py
CyrusZhang73 Sep 27, 2023
0f2a371
Update objectives.py
CyrusZhang73 Sep 27, 2023
68d9f27
Update pattern_containers.py
CyrusZhang73 Sep 27, 2023
f69bea0
Update test_patterns.py
CyrusZhang73 Sep 27, 2023
86ccc83
Update test_models.py
CyrusZhang73 Sep 27, 2023
45d510f
Update _utils.py
CyrusZhang73 Sep 29, 2023
6c26423
Update optimization.py
CyrusZhang73 Sep 29, 2023
e61c5dd
Update _utils.py
CyrusZhang73 Sep 29, 2023
a32fc8e
Delete jax_paragami directory
CyrusZhang73 Sep 29, 2023
dfdfa28
Add files via upload
CyrusZhang73 Sep 29, 2023
4a395e7
Update MANIFEST.in
CyrusZhang73 Sep 29, 2023
5e031e5
Update setup.py
CyrusZhang73 Sep 29, 2023
0d1f58b
Update approximations.py
CyrusZhang73 Sep 29, 2023
439ec10
Update optimization.py
CyrusZhang73 Sep 29, 2023
d812f71
Update _utils.py
CyrusZhang73 Sep 29, 2023
3d9a471
Update models.py
CyrusZhang73 Sep 29, 2023
95cad86
Update optimization.py
CyrusZhang73 Oct 2, 2023
224ef79
Delete viabel/jax_paragami/__pycache__ directory
CyrusZhang73 Oct 3, 2023
f623872
Update objectives.py
CyrusZhang73 Oct 3, 2023
0b6aaae
Delete viabel/jax_paragami directory
CyrusZhang73 Oct 3, 2023
ac8a6d2
Update __init__.py
CyrusZhang73 Oct 3, 2023
58c65e5
Update approximations.py
CyrusZhang73 Oct 3, 2023
8ff9e22
Add files via upload
CyrusZhang73 Oct 3, 2023
843dbcd
Update setup.py
CyrusZhang73 Oct 3, 2023
5633736
Update MANIFEST.in
CyrusZhang73 Oct 3, 2023
15509e1
Add files via upload
CyrusZhang73 Oct 3, 2023
b6ee706
Update test_optimization.py
CyrusZhang73 Oct 3, 2023
acac0f2
Update pattern_containers.py
CyrusZhang73 Oct 3, 2023
72f0f88
Delete viabel/tests/test_patterns.py
CyrusZhang73 Oct 3, 2023
1ba5a97
Delete viabel/tests/test_functions.py
CyrusZhang73 Oct 3, 2023
be36451
Add files via upload
CyrusZhang73 Oct 3, 2023
118ff29
Update test_optimization.py
CyrusZhang73 Oct 3, 2023
aea91b2
Update _psis.py
CyrusZhang73 Oct 3, 2023
0802e39
Update optimization.py
CyrusZhang73 Oct 3, 2023
8385be6
Update convenience.py
CyrusZhang73 Oct 4, 2023
117c2f5
Update test_convenience.py
CyrusZhang73 Oct 4, 2023
d22311d
Update _utils.py
CyrusZhang73 Oct 6, 2023
01ede86
Update test_models.py
CyrusZhang73 Oct 6, 2023
aa0bd65
Update test_objectives.py
CyrusZhang73 Oct 6, 2023
60d27e2
Update convenience.py
CyrusZhang73 Oct 6, 2023
301b9b9
Update setup.py
CyrusZhang73 Oct 6, 2023
5a82388
Update convenience.py
CyrusZhang73 Oct 6, 2023
a6c66a4
Delete viabel/tests/test_model.stan
CyrusZhang73 Oct 6, 2023
690eb57
Delete viabel/tests/test_model.data.json
CyrusZhang73 Oct 6, 2023
8319eb8
Add files via upload
CyrusZhang73 Oct 6, 2023
0cc7f52
Update test_optimization.py
CyrusZhang73 Oct 6, 2023
82b1b3d
Update .readthedocs.yml
CyrusZhang73 Oct 6, 2023
1414a0f
Update convenience.py
CyrusZhang73 Oct 6, 2023
3c3e5d6
Delete viabel/stan_models directory
CyrusZhang73 Oct 10, 2023
55794ee
Create test_models.stan
CyrusZhang73 Oct 10, 2023
592e57c
Add files via upload
CyrusZhang73 Oct 10, 2023
5fabcb9
Delete viabel/test_model.data.json
CyrusZhang73 Oct 10, 2023
0d38e5b
Delete viabel/test_model.stan
CyrusZhang73 Oct 10, 2023
2664e25
Update test_models.py
CyrusZhang73 Oct 10, 2023
6580ba1
Update optimization.py
CyrusZhang73 Oct 10, 2023
4656007
Update optimization.py
CyrusZhang73 Oct 10, 2023
b49fdf1
Update MANIFEST.in
CyrusZhang73 Oct 10, 2023
e53542c
Update .readthedocs.yml
CyrusZhang73 Oct 10, 2023
c3ecc7d
Update .readthedocs.yml
CyrusZhang73 Oct 10, 2023
de1a86b
Update optimization.py
CyrusZhang73 Oct 10, 2023
b904952
Update test_models.py
CyrusZhang73 Oct 10, 2023
e2f29ef
Update test_models.py
CyrusZhang73 Oct 11, 2023
09eb883
Update test_models.py
CyrusZhang73 Oct 11, 2023
f80523f
Update test_models.py
CyrusZhang73 Oct 11, 2023
d84dbf8
Update test_models.py
CyrusZhang73 Oct 11, 2023
7740320
Rename test_models.stan to test_model.stan
CyrusZhang73 Oct 11, 2023
f0ec4df
Update test_models.py
CyrusZhang73 Oct 11, 2023
39e814b
Update test_convenience.py
CyrusZhang73 Oct 11, 2023
0d641df
Update optimization.py
CyrusZhang73 Oct 11, 2023
d134b9b
Update requirements.txt
CyrusZhang73 Oct 11, 2023
6328dce
Delete viabel/base_patterns.py
CyrusZhang73 Oct 11, 2023
f15d6e3
Delete viabel/numeric_array_patterns.py
CyrusZhang73 Oct 11, 2023
3c3613d
Delete viabel/pattern_containers.py
CyrusZhang73 Oct 11, 2023
b9e22b2
Delete viabel/psdmatrix_patterns.py
CyrusZhang73 Oct 11, 2023
2160b44
Delete viabel/simplex_patterns.py
CyrusZhang73 Oct 11, 2023
f649419
Update approximations.py
CyrusZhang73 Oct 11, 2023
92750dc
Add files via upload
CyrusZhang73 Oct 11, 2023
f8ae96d
Delete viabel/tests/test_functions.py
CyrusZhang73 Oct 11, 2023
01d22a1
Delete viabel/tests/test_patterns.py
CyrusZhang73 Oct 11, 2023
2040ddb
Add files via upload
CyrusZhang73 Oct 11, 2023
390727c
Update .readthedocs.yml
CyrusZhang73 Oct 11, 2023
4929f4a
Update .readthedocs.yml
CyrusZhang73 Oct 11, 2023
203ce7c
Update __init__.py
CyrusZhang73 Oct 11, 2023
9a22ba6
Update .readthedocs.yml
CyrusZhang73 Oct 11, 2023
1b6e4ed
Update .readthedocs.yml
CyrusZhang73 Oct 11, 2023
308a3a0
Update .readthedocs.yml
CyrusZhang73 Oct 11, 2023
1991085
Update .readthedocs.yml
CyrusZhang73 Oct 11, 2023
fe3cb31
Update .readthedocs.yml
CyrusZhang73 Oct 11, 2023
bbd427b
Update .readthedocs.yml
CyrusZhang73 Oct 11, 2023
9f10bd8
Update .readthedocs.yml
CyrusZhang73 Oct 11, 2023
192a05d
Update .readthedocs.yml
CyrusZhang73 Oct 11, 2023
c82fda8
Update requirements.txt
CyrusZhang73 Oct 11, 2023
e745899
Update test_convenience.py
CyrusZhang73 Oct 11, 2023
017aeaa
Update optimization.py
CyrusZhang73 Oct 11, 2023
3e14eec
Update patterns.py
CyrusZhang73 Oct 11, 2023
baecec9
Update test_models.py
CyrusZhang73 Oct 11, 2023
b3b6910
Update test_convenience.py
CyrusZhang73 Oct 11, 2023
cccd66b
Update test_models.py
CyrusZhang73 Oct 12, 2023
fd99f46
Update test_optimization.py
CyrusZhang73 Oct 16, 2023
77ab9c9
Update test_optimization.py
CyrusZhang73 Oct 16, 2023
2d3a7ff
Update test_optimization.py
CyrusZhang73 Oct 16, 2023
b25054b
Update test_optimization.py
CyrusZhang73 Oct 16, 2023
1b5c6d8
Update test_optimization.py
CyrusZhang73 Oct 16, 2023
b0e97a2
Update test_optimization.py
CyrusZhang73 Oct 16, 2023
08f05fb
Update test_optimization.py
CyrusZhang73 Oct 16, 2023
dfd5047
Update .travis.yml
CyrusZhang73 Oct 16, 2023
a99f1aa
Update .travis.yml
CyrusZhang73 Oct 16, 2023
ea7637c
Update test_optimization.py
CyrusZhang73 Oct 16, 2023
8972fe6
Update .travis.yml
CyrusZhang73 Oct 17, 2023
b595fad
Update .travis.yml
CyrusZhang73 Oct 17, 2023
eab26a0
Update .travis.yml
CyrusZhang73 Oct 17, 2023
26e3b2a
Update .travis.yml
CyrusZhang73 Oct 17, 2023
470af96
Update test_optimization.py
CyrusZhang73 Oct 17, 2023
c27e6fa
Update test_optimization.py
CyrusZhang73 Oct 17, 2023
8b1c70e
Update test_optimization.py
CyrusZhang73 Oct 17, 2023
d4b2bc2
Update .travis.yml
CyrusZhang73 Oct 17, 2023
fbd8ddf
Update .travis.yml
CyrusZhang73 Oct 17, 2023
1e7157e
Update test_optimization.py
CyrusZhang73 Oct 18, 2023
a56accd
Update optimization.py
CyrusZhang73 Oct 18, 2023
c225580
Update test_optimization.py
CyrusZhang73 Oct 18, 2023
957ead8
Update test_objectives.py
CyrusZhang73 Oct 18, 2023
bdaba59
Update objectives.py
CyrusZhang73 Oct 18, 2023
d83074d
Update test_objectives.py
CyrusZhang73 Oct 18, 2023
28f63b6
Update approximations.py
CyrusZhang73 Oct 18, 2023
e89c983
Update approximations.py
CyrusZhang73 Oct 19, 2023
29ae60b
Update objectives.py
CyrusZhang73 Oct 19, 2023
a4da3e1
Update approximations.py
CyrusZhang73 Oct 19, 2023
0ea0e81
Update objectives.py
CyrusZhang73 Oct 19, 2023
851e1c6
Update objectives.py
CyrusZhang73 Oct 19, 2023
283a134
Update test_objectives.py
CyrusZhang73 Oct 19, 2023
1efe233
Update test_objectives.py
CyrusZhang73 Oct 19, 2023
b0d33b4
Update approximations.py
CyrusZhang73 Oct 19, 2023
b4df4c9
Update test_objectives.py
CyrusZhang73 Oct 19, 2023
8376b24
Update test_optimization.py
CyrusZhang73 Oct 20, 2023
8de7241
Update optimization.py
CyrusZhang73 Oct 21, 2023
3976e01
Update optimization.py
CyrusZhang73 Oct 21, 2023
eddf500
Update requirements.txt
CyrusZhang73 Oct 21, 2023
4f91c7d
Update requirements-dev.txt
CyrusZhang73 Oct 21, 2023
f3bb1ba
Update requirements.txt
CyrusZhang73 Oct 24, 2023
79d8279
Update requirements-dev.txt
CyrusZhang73 Oct 24, 2023
37b5599
Update requirements-dev.txt
CyrusZhang73 Oct 24, 2023
e0db797
Update optimization.py
CyrusZhang73 Oct 24, 2023
2d2f9ca
Update test_optimization.py
CyrusZhang73 Oct 24, 2023
0264df0
I modified most of them.
CyrusZhang73 Nov 27, 2023
8248b6f
change function value_and_grad()
CyrusZhang73 Dec 1, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
# Required
version: 2

# Set the OS, Python version and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.11"
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py
Expand All @@ -19,7 +24,6 @@ sphinx:

# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.8
install:
- method: pip
path: .[docs]
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ recursive-exclude * __pycache__
recursive-exclude * *.py[co]

include viabel/data/*.stan
include viabel/data/*.json

2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
codecov
coverage
pytest
pystan==2.19.1.1
pystan>=3.1.0

# lint
autoflake
Expand Down
1 change: 1 addition & 0 deletions requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ sphinx_rtd_theme
ipykernel
nbsphinx
nbstripout
bridgestan
13 changes: 7 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
numpy>=1.13
scipy~=1.7.3
autograd~=1.3
scipy>=1.7.3
jax>=0.4.1
tqdm~=4.64.1
paragami~=0.42

pystan~=2.19.1.1
bridgestan>=2.0.0
pystan>=3.1.0
pytest~=7.1.2
viabel~=0.5.1
setuptools~=65.5.0
setuptools~=65.5.0
jaxlib>=0.4.1
nest_asyncio>=1.5.8
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_long_description():
extras_require={ 'docs' : get_requirements_docs(),
'dev' : get_requirements_dev()
},
python_requires='>=3.8',
python_requires='>=3.9',
classifiers=['Programming Language :: Python :: 3',
'Natural Language :: English',
'License :: OSI Approved :: MIT License',
Expand Down
2 changes: 2 additions & 0 deletions viabel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
from viabel.models import *
from viabel.objectives import *
from viabel.optimization import *
from viabel.patterns import *
from viabel.function_patterns import *
6 changes: 3 additions & 3 deletions viabel/_distributions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import autograd.numpy as np
from autograd.numpy import linalg
from autograd.scipy import special, stats
import jax.numpy as np
from jax.numpy import linalg
from jax.scipy import special, stats


# See: https://github.com/scipy/scipy/blob/master/scipy/stats/_multivariate.py
Expand Down
20 changes: 11 additions & 9 deletions viabel/_mc_diagnostics.py
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings

import autograd.numpy as np
import jax.numpy as np
from scipy.fftpack import next_fast_len


Expand Down Expand Up @@ -64,30 +64,30 @@ def ess(samples):

rho_hat_t = np.zeros(n_draw)
rho_hat_even = 1.0
rho_hat_t[0] = rho_hat_even
rho_hat_t =rho_hat_t.at[0].set(rho_hat_even)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add space after the equals sign

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing space after equals sign

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you even need the equality if you are using .set? This applies to lines 69, 77, 78, etc. too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because JAX arrays are immutable and .at[].set() does not modify the array in place but rather creates a new array with the specified change.

rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, 1])) / var_plus
rho_hat_t[1] = rho_hat_odd
rho_hat_t = rho_hat_t.at[1].set(rho_hat_odd)

# Geyer's initial positive sequence
t = 1
while t < (n_draw - 3) and (rho_hat_even + rho_hat_odd) > 0.0:
rho_hat_even = 1.0 - (mean_var - np.mean(acov[:, t + 1])) / var_plus
rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, t + 2])) / var_plus
if (rho_hat_even + rho_hat_odd) >= 0:
rho_hat_t[t + 1] = rho_hat_even
rho_hat_t[t + 2] = rho_hat_odd
rho_hat_t = rho_hat_t.at[t+1].set(rho_hat_even)
rho_hat_t = rho_hat_t.at[t + 2].set(rho_hat_odd)
t += 2

max_t = t - 2
# improve estimation
if rho_hat_even > 0:
rho_hat_t[max_t + 1] = rho_hat_even
rho_hat_t = rho_hat_t.at[max_t + 1].set(rho_hat_even)
# Geyer's initial monotone sequence
t = 1
while t <= max_t - 2:
if (rho_hat_t[t + 1] + rho_hat_t[t + 2]) > (rho_hat_t[t - 1] + rho_hat_t[t]):
rho_hat_t[t + 1] = (rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0
rho_hat_t[t + 2] = rho_hat_t[t + 1]
rho_hat_t = rho_hat_t.at[t + 1].set((rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0)
rho_hat_t = rho_hat_t.at[t + 2].set(rho_hat_t[t + 1])
t += 2

ess = n_chain * n_draw
Expand Down Expand Up @@ -117,6 +117,7 @@ def MCSE(sample):
n_iters, d = sample.shape
sd_dev = np.sqrt(np.var(sample, ddof=1, axis=0))
eff_samp = [ess(sample[:, i].reshape(1, n_iters)) for i in range(d)]
eff_samp = np.asarray(eff_samp)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

combine this line with the previous one

mcse = sd_dev / np.sqrt(eff_samp)
return eff_samp, mcse

Expand Down Expand Up @@ -178,7 +179,8 @@ def R_hat_convergence_check(samples, windows, Rhat_threshold=1.1):
best_W: `int`
Best window size
"""
R_hat_array = [np.max(compute_R_hat(np.array(samples[-window:]), 0)) for window in windows]
R_hat_array = [np.max(compute_R_hat(np.array(samples[-window:]), 0)) for window in windows]
R_hat_array = np.asarray(R_hat_array)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

combine this line with the previous one

best_R_hat_ind = np.argmin(R_hat_array)
success = R_hat_array[best_R_hat_ind] <= Rhat_threshold
return success, windows[best_R_hat_ind]
1 change: 1 addition & 0 deletions viabel/_psis.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def psislw(lw, Reff=1.0, overwrite_lw=False):
Pareto tail indices

"""
lw = np.array(lw)
if lw.ndim == 2:
n, m = lw.shape
elif lw.ndim == 1:
Expand Down
51 changes: 1 addition & 50 deletions viabel/_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import os
import pickle
import shutil
import time
from hashlib import md5

import autograd.numpy as np
import pystan
import numpy as np


def vectorize_if_needed(f, a, axis=-1):
Expand Down Expand Up @@ -36,47 +30,4 @@ def __exit__(self, *args):
self.interval = self.end - self.start


def _data_file_path(filename):
"""Returns the path to an internal file"""
return os.path.abspath(os.path.join(__file__, '../stan_models', filename))


def _stan_model_cache_dir():
return _data_file_path('cached-stan-models')


def clear_stan_model_cache():
stan_model_dir = _stan_model_cache_dir()
if os.path.exists(stan_model_dir):
shutil.rmtree(stan_model_dir)


def StanModel_cache(model_code=None, model_name=None, **kwargs):
"""Use just as you would `StanModel`"""
if model_code is None:
if model_name is None:
raise ValueError('Either model_code or model_name must be provided')
model_file = _data_file_path(model_name + '.stan')
if not os.path.isfile(model_file):
raise ValueError('invalid model "{}"'.format(model_name))
with open(model_file) as f:
model_code = f.read()
stan_model_dir = _stan_model_cache_dir()
os.makedirs(stan_model_dir, exist_ok=True)
code_hash = md5(model_code.encode('ascii')).hexdigest()
if model_name is None:
cache_fn = 'cached-model-{}.pck'.format(code_hash)
else:
cache_fn = 'cached-{}-{}.pck'.format(model_name, code_hash)
cache_file = os.path.join(stan_model_dir, cache_fn)
if os.path.exists(cache_file):
print('Using cached StanModel{}'.format('' if model_name is None
else ' for ' + model_name))
with open(cache_file, 'rb') as f:
sm = pickle.load(f)
else:
sm = pystan.StanModel(model_code=model_code, model_name=model_name)
with open(cache_file, 'wb') as f:
pickle.dump(sm, f)

return sm
34 changes: 19 additions & 15 deletions viabel/approximations.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from abc import ABC, abstractmethod
import scipy

import autograd.numpy as np
import autograd.numpy.random as npr
import autograd.scipy.stats.norm as norm
import autograd.scipy.stats.t as t_dist
from autograd import elementwise_grad
from autograd.scipy.linalg import sqrtm
from paragami import (
FlattenFunctionInput, NumericArrayPattern, NumericVectorPattern, PatternDict,
PSDSymmetricMatrixPattern)
import jax.numpy as np
import numpy.random as npr
import jax.scipy.stats.norm as norm
import jax.scipy.stats.t as t_dist
from jax import jvp
from viabel.function_patterns import FlattenFunctionInput
from viabel.patterns import NumericArrayPattern, NumericVectorPattern, PatternDict, PSDSymmetricMatrixPattern

from ._distributions import multivariate_t_logpdf
from viabel._distributions import multivariate_t_logpdf

__all__ = [
'ApproximationFamily',
Expand Down Expand Up @@ -318,7 +317,10 @@ def _get_mu_sigma_pattern(dim):
ms_pattern['Sigma'] = PSDSymmetricMatrixPattern(size=dim)
return ms_pattern


def sqrtm(matrix):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use cholesky instead

L = scipy.linalg.cholesky(matrix)
return L

class MultivariateT(ApproximationFamily):
"""A full-rank multivariate t approximation family."""

Expand Down Expand Up @@ -415,17 +417,19 @@ def __init__(self, layers_shapes, nonlinearity=np.tanh, last=np.tanh,

def forward(self, var_param, x):
log_det_J = np.zeros(x.shape[0])
derivative = elementwise_grad(self._nonlinearity)
derivative_last = elementwise_grad(self._last)
for layer_id in range(self._layers):
W = var_param[str(layer_id)]
b = var_param[str(layer_id) + "_b"]
if layer_id + 1 == self._layers:
x = self._last(np.dot(x, W) + b)
log_det_J += np.log(np.abs(np.dot(derivative_last(x), W.T).sum(axis=1)))
_, elementwise_gradient_last = jvp(self._last, (x,), (np.ones_like(x),))

log_det_J += np.log(np.abs(np.dot(elementwise_gradient_last, W.T).sum(axis=1)))
else:
x = self._nonlinearity(np.dot(x, W) + b)
log_det_J += np.log(np.abs(np.dot(derivative(x), W.T).sum(axis=1)))
_, elementwise_gradient = jvp(self._nonlinearity, (x,), (np.ones_like(x),))

log_det_J += np.log(np.abs(np.dot(elementwise_gradient, W.T).sum(axis=1)))
return x, log_det_J

def sample(self, var_param, n_samples):
Expand Down
4 changes: 2 additions & 2 deletions viabel/convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from viabel.diagnostics import all_diagnostics
from viabel.models import Model, StanModel
from viabel.objectives import ExclusiveKL
from viabel.optimization import RAABBVI, FASO, RMSProp
from viabel.optimization import RAABBVI, FASO, RMSProp, AveragedRMSProp

all = [
'bbvi',
Expand Down Expand Up @@ -80,7 +80,7 @@ def bbvi(dimension, *, n_iters=10000, num_mc_samples=10, log_density=None,
objective = ExclusiveKL(approx, model, num_mc_samples)
if init_var_param is None:
init_var_param = approx.init_param()
base_opt = RMSProp(learning_rate, diagnostics=True, **RMS_kwargs)
base_opt = AveragedRMSProp(learning_rate, diagnostics=True, **RMS_kwargs)
if adaptive and not fixed_lr:
opt = RAABBVI(base_opt, **RAABBVI_kwargs)
elif adaptive and fixed_lr:
Expand Down
Loading