Merged
50 changes: 50 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,50 @@
name: Install and run tests

on: [push, pull_request]

jobs:
  ruff:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.13"

      - name: Install Ruff
        run: |
          python -m pip install --upgrade pip
          pip install ruff

      - name: Run Ruff
        run: ruff check . --output-format=github

  pytest:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12", "3.13"]

    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install package and dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .

      - name: Run pytest
        run: |
          pip install pytest
          python -m pytest tests/ --junitxml=junit/test-results-${{ matrix.python-version }}.xml

      - name: Upload pytest test results
        uses: actions/upload-artifact@v4
        with:
          name: pytest-results-${{ matrix.python-version }}
          path: junit/test-results-${{ matrix.python-version }}.xml
        if: ${{ always() }}
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Benjamin Dodge and Philipp Frank

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
26 changes: 22 additions & 4 deletions README.md
@@ -4,9 +4,15 @@
✅ generates Gaussian process realizations with approximately stationary, decaying kernels \
✅ scales to billions of parameters with linear time and memory requirements \
✅ effortlessly handles arbitrary point distributions with large dynamic range \
✅ uses JAX, with a faster CUDA extension that supports derivatives \
✅ uses JAX, with a faster custom CUDA extension that supports derivatives \
✅ has an exact inverse and determinant available

The underlying theory and implementation are described in two upcoming papers. It is an evolution of Iterative Charted Refinement [[1](https://arxiv.org/abs/2206.10634)], which was first implemented in the [NIFTy](https://pypi.org/project/nifty/) package. The tree algorithms are inspired by two GPU-friendly approaches [[2](https://arxiv.org/abs/2211.00120), [3](https://arxiv.org/abs/2210.12859)] originally implemented in the [cudaKDTree](https://github.com/ingowald/cudaKDTree) library.

We wrote this software for applications in astrophysics, but we hope others across the physical sciences will find it useful! Please do not hesitate to open an issue or discussion for questions and feedback :)

Authors: Benjamin Dodge, Philipp Frank


## Usage

@@ -18,8 +24,8 @@ kp, kx = jax.random.split(jax.random.key(99))
points = jax.random.normal(kp, shape=(100_000, 2))
xi = jax.random.normal(kx, shape=(100_000,))

graph = gp.build_graph(points, n0=1000, k=10)
covariance = gp.compute_matern_covariance_discrete(p=0, r_min=1e-3, r_max=1e3, n_bins=1_000)
graph = gp.build_graph(points, n0=100, k=10)
covariance = gp.extras.rbf_kernel(variance=1.0, scale=0.3, r_min=1e-4, r_max=1e1, n_bins=1_000, jitter=1e-4)
values = gp.generate(graph, covariance, xi)
```

@@ -28,9 +34,21 @@ To install, use pip. The only dependency is JAX.

```python -m pip install graphgp```

For large problems, it is recommended to install the custom CUDA extension as shown below, which will require CMake and the CUDA compiler (nvcc) installed on your system. It will take a moment to build and there may be rough edges.
For large problems, consider installing the custom CUDA extension as shown below, which will require CMake and the CUDA compiler (nvcc) installed on your system. It will take a moment to build and there may be rough edges, but memory and runtime requirements will be substantially improved. Please let us know if you encounter issues!

```python -m pip install graphgp[cuda]```


## Q&A
*How does it work?* \
The most straightforward way to generate a Gaussian Process realization at N arbitrary points is to construct a dense N x N covariance matrix, compute a matrix square root via Cholesky decomposition, and then apply it to a vector of white noise. This is equivalent to sequential generation of values, where each value is conditioned on all previous values using the Gaussian conditioning formulas. The main approximation made in GraphGP is to condition only on the values of k previously generated nearest neighbors. More details to come!
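
The dense baseline described above can be sketched in a few lines. This is a minimal illustration, not GraphGP's implementation: it uses only the Python standard library to stay dependency-free (in practice one would reach for something like `jax.numpy.linalg.cholesky`), and the `rbf` and `cholesky` helpers, the point count, and the jitter value are all assumptions chosen for the example.

```python
import math
import random


def rbf(r, variance=1.0, scale=0.3):
    """Squared-exponential kernel as a function of distance r (illustrative)."""
    return variance * math.exp(-0.5 * (r / scale) ** 2)


def cholesky(a):
    """Lower-triangular factor `low` with low @ low.T == a (a must be SPD)."""
    n = len(a)
    low = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(low[i][m] * low[j][m] for m in range(j))
            low[i][j] = math.sqrt(s) if i == j else s / low[j][j]
    return low


rng = random.Random(0)
n, jitter = 30, 1e-4
points = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(n)]
xi = [rng.gauss(0.0, 1.0) for _ in range(n)]  # white noise

# Dense N x N covariance matrix, with a small jitter on the diagonal.
cov = [[rbf(math.dist(p, q)) + (jitter if i == j else 0.0)
        for j, q in enumerate(points)] for i, p in enumerate(points)]

chol = cholesky(cov)

# Row i of the factor only touches xi[0..i]: value i is built from the noise
# behind all earlier values plus one fresh draw, which is the sequential view.
values = [sum(chol[i][j] * xi[j] for j in range(i + 1)) for i in range(n)]
```

The cost of this baseline is what motivates the approximation: the matrix needs O(N^2) memory and the factorization O(N^3) time, whereas conditioning each point on only k neighbors keeps the work per point fixed.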

*Why am I getting NaNs?* \
Just as with a dense Cholesky decomposition, GraphGP can fail if the covariance matrix becomes singular due to finite precision arithmetic. For example, two points may be so close together that their covariance is indistinguishable from their variance. A practical solution is to add "jitter" to the diagonal, as shown in the demo. Other options include reducing ``n0`` (singularity usually manifests in the dense Cholesky first), using 64-bit arithmetic, verifying that the covariance of the closest-spaced points can be represented for your choice of kernel, or increasing the number of bins for the discretized covariance. We are working to make this more user-friendly in the future.
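
To make the failure mode concrete, here is a small standard-library sketch of the two-point case. It is an assumption-laden illustration rather than GraphGP code: the `rbf` and `conditional_variance` helpers, the spacing `r`, and the jitter value are hypothetical and chosen only to trigger the effect in 64-bit floats.

```python
import math


def rbf(r, variance=1.0, scale=0.3):
    """Squared-exponential kernel as a function of distance r (illustrative)."""
    return variance * math.exp(-0.5 * (r / scale) ** 2)


def conditional_variance(r, jitter=0.0):
    """Variance left for a new point after conditioning on one neighbor at distance r."""
    var = rbf(0.0) + jitter  # diagonal entry of the 2 x 2 covariance
    cov = rbf(r)             # off-diagonal entry
    return var - cov * cov / var


r = 1e-10  # spacing far below what the kernel can resolve in floating point
no_jitter = conditional_variance(r)
with_jitter = conditional_variance(r, jitter=1e-4)
print(no_jitter, with_jitter)
```

Without jitter the covariance rounds to exactly the variance, so the conditional variance collapses to zero; with more neighbors, rounding can push it below zero, and taking its square root is one way NaNs can arise. The jitter keeps it strictly positive.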

*What is the difference between the pure JAX and custom CUDA versions?* \
The JAX version must store a (k+1) x (k+1) conditioning matrix for each point. The CUDA version generates these matrices on the fly and must only store the indices of k neighbors for each point. So we can expect roughly a factor of k better memory usage and runtime performance, depending on the exact setup.
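
As a rough back-of-the-envelope check of the memory claim, the sketch below compares the two storage schemes. The 4-byte float and 4-byte index sizes are assumptions, the helper names are hypothetical, and everything else each backend must hold (points, values, temporaries) is ignored.

```python
def conditioning_bytes(n_points, k, bytes_per_float=4):
    """Storage for one (k+1) x (k+1) conditioning matrix per point."""
    return n_points * (k + 1) ** 2 * bytes_per_float


def neighbor_index_bytes(n_points, k, bytes_per_index=4):
    """Storage for the indices of k neighbors per point."""
    return n_points * k * bytes_per_index


n_points, k = 100_000, 10
dense = conditioning_bytes(n_points, k)     # 48_400_000 bytes
sparse = neighbor_index_bytes(n_points, k)  # 4_000_000 bytes
ratio = dense / sparse                      # (k + 1)**2 / k, about 12.1 for k = 10
```

The ratio (k+1)^2 / k grows like k, which matches the "roughly a factor of k" estimate above.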

@Philipp want to write a short paragraph explaining the context? Feel free to rephrase the question etc
*How does this work with NIFTy?* \
GraphGP is not an inference package and will not help you fit your GP model to data. We encourage users to take advantage of NIFTy's inference tools, and GraphGP can serve as a drop-in replacement for ICR.
106 changes: 0 additions & 106 deletions basic.ipynb

This file was deleted.
