GPX is a library for Gaussian Process Regression (GPR) written in JAX.
GPX currently supports:
- Standard GPR
- Sparse GPR (SGPR) in the Projected Processes Approximation
- SGPR in the Projected Processes Approximation, with landmark selection using the Randomly Pivoted Cholesky Decomposition
- Radial Basis Function Networks
- Training on target values or on derivative values (using the Hessian kernel)
- Kernels with automatic support for gradient and Hessian
- Dense and sparse operations; the latter are important for scaling GPs to large datasets
- Iterative estimation of the log marginal likelihood with stochastic trace estimation and Lanczos quadrature
- Interfaces to scipy, nlopt, and optax optimizers
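For readers new to GPR, here is a minimal sketch in plain JAX of what standard (exact) GPR computes: the posterior mean and variance at test points and the log marginal likelihood, all through a Cholesky factorization of the noisy kernel matrix. This is not GPX's API; the function names `rbf_kernel` and `gpr_predict` and the fixed RBF hyperparameters are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular

# Double precision keeps the Cholesky factorization numerically stable.
jax.config.update("jax_enable_x64", True)


def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel between the rows of x1 (n, d) and x2 (m, d)."""
    sqdist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sqdist / lengthscale**2)


def gpr_predict(x_train, y_train, x_test, noise=1e-2, lengthscale=1.0, variance=1.0):
    """Exact GPR posterior mean/variance at x_test, plus the log marginal likelihood."""
    n = x_train.shape[0]
    K = rbf_kernel(x_train, x_train, lengthscale, variance) + noise * jnp.eye(n)
    L = jnp.linalg.cholesky(K)                  # K = L L^T
    alpha = cho_solve((L, True), y_train)       # alpha = K^{-1} y
    K_s = rbf_kernel(x_test, x_train, lengthscale, variance)
    mean = K_s @ alpha                          # posterior mean
    v = solve_triangular(L, K_s.T, lower=True)  # v = L^{-1} K_s^T
    var = variance - jnp.sum(v**2, axis=0)      # posterior variance (RBF has k(x, x) = variance)
    # log p(y | X) = -1/2 y^T K^{-1} y - 1/2 log|K| - n/2 log(2 pi), with log|K| = 2 sum log diag(L)
    lml = -0.5 * y_train @ alpha - jnp.sum(jnp.log(jnp.diag(L))) - 0.5 * n * jnp.log(2 * jnp.pi)
    return mean, var, lml


# Example: noise-free samples of sin(x) on [0, 5], prediction at x = 2.5.
x = jnp.linspace(0.0, 5.0, 30)[:, None]
y = jnp.sin(x[:, 0])
mean, var, lml = gpr_predict(x, y, jnp.array([[2.5]]), noise=1e-4)
print(mean, var, lml)
```

The Cholesky route costs O(n^3) time and O(n^2) memory, which is why the sparse and iterative variants in the list above matter for large datasets.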
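Training on derivative values relies on the fact that, for a GP f with kernel k, the covariance between gradients is a mixed second derivative of the kernel: cov(∂f(x1)/∂x1_i, ∂f(x2)/∂x2_j) = ∂²k(x1, x2)/∂x1_i ∂x2_j. The sketch below shows how JAX autodiff yields these gradient and Hessian kernel blocks from a scalar kernel. It illustrates the idea only; it is not how GPX implements its kernels, and the names `rbf`, `grad_k`, `hess_k`, and `hessian_kernel_matrix` are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, lengthscale=1.0):
    """Scalar squared-exponential kernel between two points of shape (d,)."""
    return jnp.exp(-0.5 * jnp.sum((x1 - x2) ** 2) / lengthscale**2)


# Gradient kernel: d k(x1, x2) / d x1, shape (d,).
# Covariance between the gradient of f at x1 and the value of f at x2.
grad_k = jax.grad(rbf, argnums=0)

# Hessian kernel: d^2 k(x1, x2) / (d x1 d x2), shape (d, d).
# Covariance between the gradient of f at x1 and the gradient of f at x2.
hess_k = jax.jacfwd(jax.grad(rbf, argnums=0), argnums=1)


def hessian_kernel_matrix(X1, X2):
    """Assemble the (n*d, m*d) covariance between gradients at the rows of X1 and X2."""
    blocks = jax.vmap(lambda a: jax.vmap(lambda b: hess_k(a, b))(X2))(X1)  # (n, m, d, d)
    n, m, d, _ = blocks.shape
    # Row index i*d + a, column index j*d + b.
    return blocks.transpose(0, 2, 1, 3).reshape(n * d, m * d)


x1, x2 = jnp.array([1.0, 0.0]), jnp.array([0.0, 0.0])
print(grad_k(x1, x2))  # gradient block
print(hess_k(x1, x2))  # Hessian block
X = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
H = hessian_kernel_matrix(X, X)  # shape (15, 15)
```

For the RBF kernel with unit lengthscale the Hessian block has the closed form k(x1, x2) (I - r r^T) with r = x1 - x2, which is a convenient check on the autodiff result.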
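The expensive term in the log marginal likelihood is log det K = tr(log K). Stochastic trace estimation replaces the trace with an average of z^T log(K) z over random probe vectors z (here Rademacher, so E[z z^T] = I), and Lanczos quadrature approximates each quadratic form from a few matrix-vector products: with T = U diag(θ) U^T the Lanczos tridiagonal matrix started at z, z^T log(K) z ≈ ||z||² Σ_k U[0, k]² log θ_k. The sketch below is a generic illustration of that technique, not GPX's code; `lanczos_tridiag`, `slq_logdet`, `num_probes`, and `num_steps` are hypothetical names, and full reorthogonalization is used only for simplicity.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def lanczos_tridiag(matvec, q0, num_steps):
    """Run `num_steps` Lanczos iterations (full reorthogonalization); return tridiagonal T."""
    n = q0.shape[0]
    Q = jnp.zeros((num_steps, n))
    alphas, betas = [], []
    q = q0 / jnp.linalg.norm(q0)
    q_prev = jnp.zeros(n)
    beta = 0.0
    for j in range(num_steps):
        Q = Q.at[j].set(q)
        w = matvec(q) - beta * q_prev
        alpha = q @ w
        w = w - alpha * q
        w = w - Q[: j + 1].T @ (Q[: j + 1] @ w)  # re-orthogonalize against all previous vectors
        alphas.append(alpha)
        if j == num_steps - 1:
            break
        beta = jnp.linalg.norm(w)
        betas.append(beta)
        q_prev, q = q, w / beta
    a, b = jnp.array(alphas), jnp.array(betas)
    return jnp.diag(a) + jnp.diag(b, 1) + jnp.diag(b, -1)


def slq_logdet(matvec, n, key, num_probes=64, num_steps=25):
    """Stochastic Lanczos quadrature estimate of log det(A) for symmetric positive definite A."""
    estimates = []
    for _ in range(num_probes):
        key, subkey = jax.random.split(key)
        z = jax.random.rademacher(subkey, (n,)).astype(jnp.float64)  # Hutchinson probe
        theta, U = jnp.linalg.eigh(lanczos_tridiag(matvec, z, num_steps))
        # Gauss quadrature: z^T log(A) z ≈ ||z||^2 * sum_k U[0, k]^2 * log(theta_k)
        estimates.append((z @ z) * jnp.sum(U[0] ** 2 * jnp.log(theta)))
    return jnp.mean(jnp.array(estimates))


B = jax.random.normal(jax.random.PRNGKey(1), (50, 50))
A = jnp.eye(50) + B @ B.T / 50
est = slq_logdet(lambda v: A @ v, 50, jax.random.PRNGKey(2))
exact = jnp.linalg.slogdet(A)[1]
print(est, exact)
```

Only `matvec` touches the matrix, so the same estimator works when K is never formed explicitly; the quadratic term y^T K^{-1} y would typically be handled separately, for example with conjugate gradients.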
A Python 3.10 environment is recommended. You can create one with conda, virtualenv, or pyenv.
Then simply clone the project and install it with pip.
For example, using conda:
conda create -n gpx-env python=3.10
conda activate gpx-env
git clone https://github.com/Molecolab-Pisa/GPX
cd GPX
pip install .
If you need JAX with GPU support, install JAX first, following the instructions provided by JAX.
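After installing, one quick way to confirm which backend JAX actually picked up (for example, whether a GPU build is active) is the short check below. It uses only JAX's public `jax.default_backend()` and `jax.devices()` functions and does not involve GPX itself.

```python
import jax

# Report which backend JAX is using (e.g. "cpu" or "gpu") and the devices it can see.
print(jax.default_backend())
print(jax.devices())
```

If a GPU build was installed but the backend is reported as "cpu", JAX did not find a compatible accelerator.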
You may want to look at our list of examples:
- GPR
- SGPR
- SGPR with RPCholesky
- GPR with derivatives
- Simple Multioutput GP
- Interface to NLOpt
- Kernelizers and Kernel Operations
- Maximum a Posteriori estimate
- Model Persistence in GPX
- Kernel derivatives
To cite GPX, you can use the following BibTeX entry:
@software{gpx2023github,
author = {Edoardo Cignoni and Patrizia Mazzeo and Amanda Arcidiacono and Lorenzo Cupellini and Benedetta Mennucci},
title = {GPX: Gaussian Process Regression in JAX},
url = {https://github.com/Molecolab-Pisa/GPX},
version = {0.1.0},
year = {2023},
}
