
# JAXtronomy

![codecov](https://codecov.io/gh/lenstronomy/JAXtronomy/graph/badge.svg?token=6EJAX8CF62) ![PyPI](https://img.shields.io/pypi/v/JAXtronomy?label=PyPI&logo=pypi)

JAX port of lenstronomy, for parallelized, GPU accelerated, and differentiable gravitational lensing and image simulations.

The goal of this library is to reimplement lenstronomy functionalities in pure JAX to allow for automatic differentiation, GPU acceleration, and batched computations.

Guiding Principles:

  • Strive to be a drop-in replacement for lenstronomy, i.e. provide a close match to the lenstronomy API.
  • Each function/feature will be tested against the reference lenstronomy implementation.
  • This package will aim to be a subset of lenstronomy (i.e. only contains functions with a reference lenstronomy implementation).
  • Implementations should be easy to read and understand.
  • Code should be pip installable on any machine, no compilation required.
  • Any notable differences between the JAX and reference implementations will be clearly documented.

## Installation

JAXtronomy can be installed with

pip install jaxtronomy

## Performance comparison between JAXtronomy and lenstronomy

We compare the runtimes between JAXtronomy and lenstronomy by timing 10,000 function executions. While lenstronomy is always run using CPU, JAXtronomy can be run using either CPU or GPU.
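Timing JAX code has two subtleties that any such comparison must handle: the first call to a jit-compiled function includes compilation time, and JAX dispatches work asynchronously, so the result must be explicitly waited on. Below is a minimal sketch of this timing pattern. The deflection function is the textbook singular isothermal sphere (SIS) formula, used here as an illustrative stand-in; it is not jaxtronomy's implementation or the project's actual benchmark script.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def sis_deflection(x, y, theta_E):
    # Standard SIS deflection: alpha(theta) = theta_E * theta / |theta|
    r = jnp.sqrt(x**2 + y**2)
    return theta_E * x / r, theta_E * y / r


# Illustrative 180x180 coordinate grid (avoids an exact zero at the origin)
xs = jnp.linspace(-1.0, 1.0, 180)
x, y = jnp.meshgrid(xs, xs)

# Warm-up call: triggers JIT compilation, which must be excluded from timing
ax, ay = sis_deflection(x, y, 1.0)
ax.block_until_ready()

start = time.perf_counter()
for _ in range(10_000):
    ax, ay = sis_deflection(x, y, 1.0)
ax.block_until_ready()  # wait for the asynchronous dispatch queue to drain
elapsed = time.perf_counter() - start
print(f"10,000 executions: {elapsed:.3f} s")
```

Without the warm-up call and `block_until_ready()`, a naive loop would either charge compilation to the first iteration or stop the clock before the device has finished computing.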

### LensModel ray-shooting

The table below shows how much faster JAXtronomy is compared to lenstronomy for different deflector profiles and different grid sizes. These tests were run using an Intel(R) Xeon(R) Gold 6338 CPU @ 2.00GHz and an NVIDIA A100 GPU.

| Deflector Profile | 60x60 grid (JAX w/ cpu) | 180x180 grid (JAX w/ cpu) | 180x180 grid (JAX w/ gpu) |
| --- | --- | --- | --- |
| CONVERGENCE | 0.4x | 1.3x | 0.4x |
| CSE | 1.6x | 2.9x | 2.3x |
| EPL | 1.1x - 15x | 1.6x - 17x | 6.4x - 120x |
| EPL (jax) vs EPL_NUMBA | 1.8x | 3.2x | 13x |
| EPL_MULTIPOLE_M1M3M4 | 1.1x - 7x | 3.3x - 13x | 22x - 108x |
| EPL_MULTIPOLE_M1M3M4_ELL | 1.2x - 3.3x | 1.1x - 3.3x | 18x - 140x |
| GAUSSIAN | 1.0x | 1.8x | 3.0x |
| GAUSSIAN_POTENTIAL | 0.9x | 1.7x | 2.4x |
| HERNQUIST | 1.9x | 3.6x | 6.4x |
| HERNQUIST_ELLIPSE_CSE | 3.8x | 5.9x | 40.3x |
| MULTIPOLE | 0.9x | 1.0x | 10.0x |
| MULTIPOLE_ELL (m=1, m=3, m=4) | 1.5x, 1.5x, 2.1x | 1.3x, 1.3x, 1.9x | 92x, 92x, 90x |
| NFW | 1.6x | 3.3x | 5.0x |
| NFW_ELLIPSE_CSE | 4.1x | 5.7x | 36.5x |
| NIE | 0.5x | 0.5x | 2.0x |
| PJAFFE | 1.0x | 1.2x | 3.0x |
| PJAFFE_ELLIPSE_POTENTIAL | 1.5x | 1.6x | 3.1x |
| SHEAR | 0.7x | 2.2x | 1.0x |
| SIE | 0.5x | 0.5x | 2.0x |
| SIS | 1.4x | 3.0x | 2.0x |
| TNFW | 2.4x | 5.4x | 8.3x |

Note that some profiles' runtimes depend on the function arguments. For example, the EPL profile involves a hyp2f1 calculation performed with a power series expansion; the more terms required, the better JAXtronomy performs compared to lenstronomy.
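To make the hyp2f1 dependence concrete, here is a minimal power-series evaluation of the Gauss hypergeometric function 2F1(a, b; c; z), valid for |z| < 1, shown in plain Python for clarity. This is an illustrative sketch, not jaxtronomy's implementation; note that jit-compiled JAX code favors a fixed, static term count rather than a data-dependent convergence test, and the per-call cost scales with the number of terms retained.

```python
def hyp2f1_series(a, b, c, z, n_terms=64):
    """Evaluate 2F1(a, b; c; z) = sum_n (a)_n (b)_n / (c)_n * z^n / n!
    by its power series, using a fixed number of terms (|z| < 1).

    Illustrative helper, not jaxtronomy's implementation.
    """
    total = term = 1.0  # n = 0 term
    for n in range(n_terms):
        # Ratio of consecutive terms:
        # t_{n+1} = t_n * (a+n)(b+n) / ((c+n)(1+n)) * z
        term *= (a + n) * (b + n) / ((c + n) * (1 + n)) * z
        total += term
    return total
```

As a sanity check, 2F1(1, 1; 2; z) = -ln(1 - z) / z, so `hyp2f1_series(1.0, 1.0, 2.0, 0.5)` should be close to 2 ln 2 ≈ 1.38629.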

A performance comparison notebook is available for more detailed analysis.

### LightModel surface brightness

The table below shows how much faster JAXtronomy is compared to lenstronomy for different source profiles and different grid sizes. These tests were run using an Intel(R) Xeon(R) Gold 6338 CPU @ 2.00GHz and an NVIDIA A100 GPU.

| Source Profile | 60x60 grid (JAX w/ cpu) | 180x180 grid (JAX w/ cpu) | 180x180 grid (JAX w/ gpu) |
| --- | --- | --- | --- |
| CORE_SERSIC | 2.1x | 10.2x | 4.4x |
| GAUSSIAN | 1.6x | 3.4x | 1.6x |
| GAUSSIAN_ELLIPSE | 1.5x | 6.9x | 2.1x |
| MULTI_GAUSSIAN (5 components) | 3.7x | 16.2x | 7.8x |
| MULTI_GAUSSIAN_ELLIPSE (5 components) | 4.4x | 18.3x | 7.2x |
| SERSIC | 2.3x | 9.3x | 4.2x |
| SERSIC_ELLIPSE | 2.1x | 8.5x | 3.2x |
| SERSIC_ELLIPSE_Q_PHI | 1.7x | 8.6x | 3.4x |
| SHAPELETS (n_max=6) | 8.0x | 5.2x | 17.6x |
| SHAPELETS (n_max=10) | 8.9x | 6.1x | 22.4x |
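To illustrate the kind of computation these numbers measure, the sketch below evaluates a circular Sersic profile on a 180x180 grid in a single vectorized, jit-compiled call. The formula is the standard Sersic law with the Capaccioli approximation for b_n; the function name and parameters are illustrative assumptions, not jaxtronomy's API.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sersic_sb(x, y, amp, r_eff, n_sersic):
    """Surface brightness of a circular Sersic profile:
    I(r) = amp * exp(-b_n * ((r / r_eff)^(1/n) - 1)).

    Illustrative stand-in for a LightModel evaluation, not
    jaxtronomy's implementation.
    """
    # Capaccioli approximation for b_n, accurate for roughly 0.5 < n < 10
    b_n = 1.9992 * n_sersic - 0.3271
    r = jnp.sqrt(x**2 + y**2)
    return amp * jnp.exp(-b_n * ((r / r_eff) ** (1.0 / n_sersic) - 1.0))


# Evaluate the whole 180x180 grid in one jit-compiled call
xs = jnp.linspace(-2.0, 2.0, 180)
x, y = jnp.meshgrid(xs, xs)
image = sersic_sb(x, y, amp=1.0, r_eff=0.5, n_sersic=4.0)
```

Because the profile is a pure elementwise function of the grid, XLA can fuse the whole evaluation into one kernel, which is where the larger-grid speedups in the table plausibly come from.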

### FFT Pixel Kernel Convolution

Convolution runtimes vary significantly depending on both grid size and kernel size. A short summary is as follows, using an Intel(R) Xeon(R) Gold 6338 CPU @ 2.00GHz and an NVIDIA A100 GPU.

  • For a 60x60 grid and kernel sizes ranging from 3 to 45, JAXtronomy on CPU is about 1.1x to 2.9x faster than lenstronomy, with no obvious correlation to kernel size.
  • For a 60x60 grid and kernel sizes ranging from 3 to 45, JAXtronomy on GPU is about 1.5x to 3.5x faster than lenstronomy, performing better at larger kernel sizes.
  • For a 180x180 grid and kernel sizes ranging from 9 to 135, JAXtronomy on CPU is about 0.7x to 2.5x as fast as lenstronomy, with no obvious correlation to kernel size.
  • For a 180x180 grid and kernel sizes ranging from 9 to 135, JAXtronomy on GPU is about 10x to 20x as fast as lenstronomy, performing better at larger kernel sizes.

A performance comparison notebook is available for more detailed analysis.
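As a concrete illustration of what FFT-based PSF convolution computes (shown here with SciPy's `fftconvolve`; this is not jaxtronomy's code, whose JAX-based implementation may differ), convolving a delta-function image with a pixelized kernel reproduces the kernel centered on the spike:

```python
import numpy as np
from scipy.signal import fftconvolve

# A delta-function "image": convolving it with the PSF kernel should
# reproduce the kernel itself, centered on the spike.
image = np.zeros((9, 9))
image[4, 4] = 1.0

# Simple normalized 3x3 kernel (illustrative; pixelized PSF kernels in
# practice can be much larger, as in the benchmarks above)
kernel = np.array([[1.0, 2.0, 1.0],
                   [2.0, 4.0, 2.0],
                   [1.0, 2.0, 1.0]])
kernel /= kernel.sum()

# mode="same" keeps the output the same shape as the input image
convolved = fftconvolve(image, kernel, mode="same")
```

FFT convolution costs O(N log N) in the number of pixels regardless of kernel size, which is consistent with the GPU advantage above growing for larger kernels, where direct convolution would be most expensive.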

## Related software packages

The following lensing software packages also use JAX-accelerated computing and were in part inspired by, or make use of, lenstronomy functions:
