Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Example: Bayes NN with steinVI (pyro-ppl#1297)
* added stein example * added test case * added stein bnn to docs. * moveed stein_bnn to other inf algs in docs * Added correct plating for model in `stein_bnn.py`. Works with latest pyro-ppl#833. * Add some doctests to transforms (pyro-ppl#1300) * add some doctest to transforms * make format * Tutorial for truncated distributions (pyro-ppl#1272) * WIP Do not merge. Tutorial for truncated distributions * WIP: Completed a few todos and fixed a few typos * WIP: Completed main sections. References and part 5 still pending * Added section on built in distributions and folded distributions * Draft ready * Remove M1-related warning from cell output * Truncated distributions tutorial added to index * Wrap latex equations in double dollar sign * Fix broken markdown equations * Added more details on folded distribs. Re-arranged sections. * Test: Change title level. * Links now point to the docs instead of the source code. Fixed some broken formatting of the titles. Use different seeds for Prior/Inference/Prediction. Changed models for inferring the truncation. Fixed minor typos. * Install numpyro and upgrade jax, jaxlib and matplotlib Copy jax arrays before passing to matplotlib functions * Clarified statement about the log_prob method in the TruncatedDistribution class. * Changed intro sentence to include folded distributions. * Remove command for installing jax. Use np.unique instead of jnp.unique * Cast rate parameter to float (pyro-ppl#1301) * Make potential_fn_gen and postprocess_fn_gen picklable (pyro-ppl#1302) * add wrapper * Make potential_fn_gen postprocess_fn_gen pickable * Stein based inference (pyro-ppl#833) * Added stein interface. * Fixed style and removed from VI baseclass. * Added reinit_guide.py * Added license. * added examples * Added examples. * Fixed some linting and LDA example; need to refactor wrapped_guide. * Added param site also get rng_keys; this should be reworked! * Removed datasets and fixed lda to running. * Fixed dimensionality bug for simplex support. * Added code from refactor/einstein * Fixed notebooks; todo: comment notebook. * Factored initialization of `kernels.RandomFeatureKernel` into `Stein.init` and updated `test_kernels.test_kernel_forward` accordingly. * Started testing. * Removed assert from test_init_strategy. * Skeleton test_stein.py * Updated `test_stein/test_init` * Added test_params and likelihood computation to lda. * Fix init in MixtureKernel * Notebook fixes * debugging log likelihood * WIP, move benchmarks to datasets * trace guide to compute likelihood in lda. * Debugging LDA * Removed test_vi.py (will use test_stein.py), added `test_stein.test_update_evaluate` * Cleaned test covered by `test_get_params`. * Added skeleton and finished _param_size test. * Fix LR example * IRIS LR * Fix Toy examples * Added pinfo test. * moved stein/test_kernels.py into stein/test_stein.py; updated `test_stein.test_apply_kernel` * Ran black and removed lambdas from KERNEL_TEST_CASE. * Added `test_stein.test_sp_mcmc` and removed calls to jnp.random.shuffle (deprecated). * Added skelelton test for test_score_sp_mcmc. * Fixed overwriting kval in `test_stein.test_apply_kernel` * Fixed lint * Fixed lint. * Added stein_loss test. * Factored vi source and test_vi out of einstein. * updated with black. * Figured out likelihood for LDA (need to change to compute likelihood instead of ELBO) * Added perplexity to LDA. * Fixed log position for perplexity. * Refactored callbacks and added `test_checkpoint`. * Fixed imports * Reverted LDA to working version. * Added callback tests * Return loss history for `stein.run` * Added visual to LDA. * Fixed return for `run_lda` * Added missing topic num 20. * Added todo * Cleaned 1d_mm stein notebook. * Updated 2d gaussian notebook. * Add description to SVGD. * Updated `RBF_kernel` to work with one particle and added kernels notebook. * Fixed bug in bandwidth of RBF_kernel * SVI reproducing result from SVGD paper. * Better learning rate for SVI. * larger network * Updated predictive to allow for particle methods. * Removed TODO and fixed learning rate. * EinStein out performance SVGD * Latest working. * Fixed VI without progressbar. * Fixed mini batching for VI. * Added kernel visualization. * Init to sample for bayesian networks. * TODO predict shape. * Added scaling to plate primitive. * Fixed enumeration in Stein and added subsample_scale to funsor.plate. * Debugging LDA * Debugging lda * Debugging merge. * Updated jacobian computation in Stein. * Fixed issue with nested parameters for stein grad. * Fixed issue with nested parameters for stein grad. * Added NLL to DMM and predictions. * renaming and removing benchmark code * Cleaning branch from benchmarking. * Removed prediction from DMM. * Changed to syntax from older python * Fixed lint. * Fixed reinit warning for `init_to_uniform`. * updated to use black[jupyter] * Added licenses. * Added smoke a smoke test for SteinVI * Factored out Stein point MCMC * Factored out VI from EinStein. * Updated stein_kernels.ipynb and removed debugging pred_prey. * Removed `mixture.py` use `mixtures.py` instead. * Fixed lint. * Added examples to docs build. * Fixed stale import in `hmc.py` docstring. * Removed stein point test cases. * Changed Predictive to only check for guided models with particles. * Fixed lint. * Changed `reinit_guide` to add rng_keys for reinitialization. * Added boston pricing dataset. Commented stein bnn example. Added `stein_bnn.py` to `test_examples.py`. * Removed empty line from `test_examples.py`. * Fixed lint. * Changed `stein_mixture_dmm.py` to use new signature and run method. * Added some comments and fixed `stein_mixture_dmm.py` to use new signature. * Fixed `event_shape` and `support` for `Sine`, `DoubleBanana`, and `Star` distributions in `stein_2d_dists.py`. * Removed notebooks from initial PR and updated stein_2d_toy.py to new run signature. * Parameterize `gru_dim` in `stein_mixture_dmm.py`. * Fixed steinvi to use pyro-ppl#1263; TODO: update examples. * Removed init_with_noise. * removed stein_bnn and changed `examples/datasets.py` to upstream * removed examples * Removed stein examples from docs. * renamed einstein.utils to einstein.util. * updated testing * Changed test to use auto_guide `init_loc_fn`. * removed `numpyro/util/ravel_pytree` * removed unused imports in numpyro/util * Added initialization to kernels in `test_einstein_kernels.py` * Changed kernel test to use np.arrays at global level. * change jnp arrays to np np array in tests. reverted subsample scale. * added docstring to `einstein/util/batch_ravel_pytree` Co-authored-by: Ahmad Salim Al-Sibahi <[email protected]> Co-authored-by: einsteinvi <[email protected]> * Improve subsample warning keys (pyro-ppl#1303) * Add ProvenanceArray to infer relational structure in a model (pyro-ppl#1248) * Add provenance array * Add tests for provenance * run make format * Workaround not be able to eval_shape a distribution * Make license * add a clearer guide for render a model with scan * fix failing bugs in recent jax release * Fix further failing tests * Make sure to be able to render ImproperUniform and random initialized params * port get_dependencies to numpyro * tighten test_improper_normal bound (pyro-ppl#1307) * Fix HMCECS multiple plates (pyro-ppl#1305) * Add Kumaraswamy and relaxed Bernoulli distributions (pyro-ppl#1283) * Add kumaraswamy and relaxed bernoulli distributions * clean up the flag * Require logits to be keyword argument * make relaxed bernoulli have the same signature as Pyro * fix docs build * Fix rsample bug * add more simple test for Kumaraswamy * Add various KL divergences for Gamma/Beta families (pyro-ppl#1284) * Add new distributions and kl * Add kumaraswamy and relaxed bernoulli distributions * clean up the flag * Require logits to be keyword argument * make relaxed bernoulli have the same signature as Pyro * fix docs build * Fix rsample bug * move the flag to Kumaraswamy class for convenient * Add loose strategy for missing plates in MCMC (pyro-ppl#1304) * Add loose strategy for MCMC * merge svi and mcmc plate warning strategies * fix failing tests * validate model accross ELBOs * update vae example * fix typos * Fix failing tests * skip prodlda test on CI * Bump to 0.9.0 (pyro-ppl#1310) * Add loose strategy for MCMC * merge svi and mcmc plate warning strategies * fix failing tests * validate model accross ELBOs * update vae example * fix typos * Bump to version 0.9.0 * Fix failing tests * Fix warnings in tests/examples * relax funsor requirement * Move optax_to_numpyro to optim * skip prodlda test on CI * added dimensions to plate and sqrt precision. * fixed/added comments in stein_bnn.py and removed lr datasets. * added comment to stein_bnn.py * formatted files to black==22.1.0 Co-authored-by: Wataru Hashimoto <[email protected]> Co-authored-by: Omar Sosa Rodríguez <[email protected]> Co-authored-by: Vedran Hadziosmanovic <[email protected]> Co-authored-by: Du Phan <[email protected]> Co-authored-by: Ahmad Salim Al-Sibahi <[email protected]> Co-authored-by: einsteinvi <[email protected]> Co-authored-by: austereantelope <[email protected]>
- Loading branch information