-
Notifications
You must be signed in to change notification settings - Fork 15
Switching to Bridgestan and JAX #58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Codecov ReportAttention:
❗ Your organization needs to install the Codecov GitHub app to enable full functionality. Additional details and impacted files@@ Coverage Diff @@
## master #58 +/- ##
==========================================
- Coverage 92.73% 87.20% -5.53%
==========================================
Files 6 8 +2
Lines 1211 2228 +1017
==========================================
+ Hits 1123 1943 +820
- Misses 88 285 +197 ☔ View full report in Codecov by Sentry. |
viabel/convenience.py
Outdated
|
|
||
|
|
||
| def bbvi(dimension, *, n_iters=10000, num_mc_samples=10, log_density=None, | ||
| def bbvi(dimension, *, n_iters=3000, num_mc_samples=10, log_density=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't change defaults; if you need it to be different for the test, change it manually in the test
viabel/convenience.py
Outdated
|
|
||
|
|
||
| def vi_diagnostics(var_param, *, objective=None, model=None, approx=None, n_samples=100000): | ||
| def vi_diagnostics(var_param, *, objective=None, model=None, approx=None, n_samples=3000): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here; don't change defaults
| rho_hat_t = np.zeros(n_draw) | ||
| rho_hat_even = 1.0 | ||
| rho_hat_t[0] = rho_hat_even | ||
| rho_hat_t =rho_hat_t.at[0].set(rho_hat_even) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add space after the equals sign
| return ms_pattern | ||
|
|
||
|
|
||
| def sqrtm(matrix): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use cholesky instead
| @@ -0,0 +1,393 @@ | |||
| import copy | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
credit paragami
viabel/optimization.py
Outdated
| # if descent_dir_history is not None: | ||
| # results['descent_dir_history'] = descent_dir_history | ||
| results_dict = {d: np.array(h) for d, h in results.items()} | ||
| results_dict = {d: jnp.array(h) for d, h in results.items()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not use numpy here and below?
viabel/optimization.py
Outdated
| fit = model.sampling(data=data, init=init, iter=1000, chains=n_chains, | ||
| control=dict(adapt_delta=0.98)) #sampling from the model | ||
| model = stan.build(program_code=model_code, data=data) | ||
| fit = model.sample(num_chains=n_chains, num_samples=1000,init = init) # sampling from the model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this really a fit? Or should it be called something else?
viabel/tests/test_convenience.py
Outdated
| @@ -1,7 +1,7 @@ | |||
| import autograd.numpy as anp | |||
| import jax.numpy as jnp | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logical ordering of imports
viabel/tests/test_convenience.py
Outdated
| @@ -1,7 +1,7 @@ | |||
| import autograd.numpy as anp | |||
| import jax.numpy as jnp | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do you need to use Jax for everything here? seems only necessary if you are going to autodiff. Same question applies to other test files
viabel/tests/test_models.py
Outdated
| def _test_model(m, x, supports_tempering, supports_constrain): | ||
| check_vjp(m, x) | ||
| check_vjp(m, x[0]) | ||
| #check_vjp(m, (x,), modes=['rev'], order=2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete these comments?
jhuggins
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good but the tests aren't passing
|
@CyrusZhang73 just checking whether you are still planning to work on this PR? I'm asking as @charlesm93 is interested in testing Viabel with BridgeStan. |
No description provided.