A package for computing the inversion of the cyclogeostrophic balance based on a variational formulation approach, using JAX.

meom-group/jaxparrow

jaxparrow

jaxparrow implements a novel minimization-based formulation for computing the inversion of the cyclogeostrophic balance.

It leverages JAX to solve the inversion efficiently as a minimization problem. Given the Sea Surface Height (SSH) field of an ocean system, jaxparrow estimates the velocity field that best satisfies the cyclogeostrophic balance.
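
For readers unfamiliar with the balance itself, it can be written out as follows. The notation is an assumption made for this sketch: \eta is the SSH, g the gravitational acceleration, f the Coriolis parameter, \mathbf{k} the vertical unit vector, \mathbf{u}_g the geostrophic velocity, and \mathbf{u}_{cg} the cyclogeostrophic velocity. The cost function J is one natural way to state the minimization problem; the documentation gives the exact form used by the package.

```latex
% Geostrophic velocity derived from the SSH
\mathbf{u}_g = \frac{g}{f}\,\mathbf{k} \times \nabla \eta

% Cyclogeostrophic balance: geostrophy plus the advective (centrifugal) term
\mathbf{u}_{cg} - \frac{\mathbf{k}}{f} \times \left(\mathbf{u}_{cg} \cdot \nabla\right)\mathbf{u}_{cg} = \mathbf{u}_g

% Inversion posed as a minimization problem over the velocity field
J(\mathbf{u}) = \left\lVert \mathbf{u} - \frac{\mathbf{k}}{f} \times \left(\mathbf{u} \cdot \nabla\right)\mathbf{u} - \mathbf{u}_g \right\rVert^2
```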

Comprehensive documentation is available at https://jaxparrow.readthedocs.io/en/latest/.

Installation

jaxparrow is pip-installable:

pip install jaxparrow

However, users with access to GPUs or TPUs should first install JAX separately in order to fully benefit from its high-performance computing capabilities. See the JAX installation instructions. By default, jaxparrow will install a CPU-only version of JAX if no other version is already present in the Python environment.
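
To check which JAX build ended up in the environment, a quick sanity check along these lines may help. It only uses standard JAX calls (jax.default_backend and jax.devices) and does not depend on jaxparrow.

```python
import jax

# Name of the backend JAX selected, e.g. "cpu", "gpu" or "tpu"
backend = jax.default_backend()

# Devices visible to JAX on that backend
devices = jax.devices()

print(f"JAX backend: {backend}")
print(f"Devices: {devices}")
```

If the backend is reported as "cpu" on a machine with an accelerator, JAX was most likely installed without accelerator support.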

Usage

The function you are most likely looking for is cyclogeostrophy. It computes the cyclogeostrophic velocity field (returned as two 2D arrays) from:

  • an SSH field (a 2D array),
  • the latitude and longitude grids at the T points (two 2D arrays).

In a Python script, assuming that the input grids have already been initialised or imported, estimating the cyclogeostrophic velocities for a single timestamp comes down to:

from jaxparrow import cyclogeostrophy

ucg, vcg = cyclogeostrophy(ssh_2d, lat_2d, lon_2d, return_grids=False)
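
The call above assumes that ssh_2d, lat_2d and lon_2d already exist. For quick experiments, synthetic inputs can be built as in the sketch below. Everything here is illustrative: the domain bounds, the grid size, the Gaussian anomaly and the (latitude, longitude) array layout are assumptions, not requirements stated by jaxparrow.

```python
import jax.numpy as jnp

# Hypothetical regular 1D axes in degrees; real data would come from a dataset
lat_1d = jnp.linspace(30.0, 40.0, 50)
lon_1d = jnp.linspace(-60.0, -45.0, 80)

# 2D grids at the T points, both of shape (n_lat, n_lon)
lon_2d, lat_2d = jnp.meshgrid(lon_1d, lat_1d)

# Synthetic Gaussian SSH anomaly (in metres) centred in the domain
lat0, lon0, sigma_deg, amplitude_m = 35.0, -52.5, 1.0, 0.2
dist2 = (lat_2d - lat0) ** 2 + (lon_2d - lon0) ** 2
ssh_2d = amplitude_m * jnp.exp(-dist2 / (2.0 * sigma_deg ** 2))

print(ssh_2d.shape, lat_2d.shape, lon_2d.shape)
```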

Because jaxparrow uses C-grids, the velocity fields are represented on two grids (U and V), and the tracer fields (such as SSH) on a third one (T). We provide functions that compute some kinematics (such as velocity magnitude, normalized relative vorticity, or kinetic energy) while accounting for this gridding system:

from jaxparrow.tools.kinematics import magnitude

uv_cg = magnitude(ucg, vcg)
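
To see why the staggering matters, here is a minimal sketch of a magnitude computed on the T grid. It is not jaxparrow's implementation: the helper names (u_to_t, v_to_t, magnitude_on_t), the staggering convention (each U point to the east and each V point to the north of its T point) and the edge handling are all assumptions made for illustration.

```python
import jax.numpy as jnp


def u_to_t(u):
    # Average each U value with its western neighbour (assumed convention).
    # The first column is paired with itself, so it is left unchanged.
    u_west = jnp.concatenate([u[:, :1], u[:, :-1]], axis=1)
    return 0.5 * (u + u_west)


def v_to_t(v):
    # Average each V value with its southern neighbour (assumed convention).
    # The first row is paired with itself, so it is left unchanged.
    v_south = jnp.concatenate([v[:1, :], v[:-1, :]], axis=0)
    return 0.5 * (v + v_south)


def magnitude_on_t(u, v):
    # Both components must be co-located before they are combined
    return jnp.sqrt(u_to_t(u) ** 2 + v_to_t(v) ** 2)


# Uniform fields make the result easy to verify by hand
u = jnp.full((4, 5), 3.0)
v = jnp.full((4, 5), 4.0)
speed = magnitude_on_t(u, v)
print(speed.shape)
```

Combining u and v directly, without bringing them to a common grid first, would mix values taken at different locations.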

To vectorise the estimation of the cyclogeostrophic velocities along a leading time dimension, one can use jax.vmap:

import jax

vmap_cyclogeostrophy = jax.vmap(cyclogeostrophy, in_axes=(0, None, None))
u_cg_3d, v_cg_3d = vmap_cyclogeostrophy(ssh_3d, lat_2d, lon_2d)
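
The in_axes=(0, None, None) argument maps over the first axis of the SSH only and reuses the same latitude and longitude grids at every time step. The sketch below illustrates this with toy_velocity, a hypothetical stand-in that has the same argument layout but performs no physical computation.

```python
import jax
import jax.numpy as jnp


def toy_velocity(ssh, lat, lon):
    # Hypothetical stand-in for a function taking (ssh_2d, lat_2d, lon_2d).
    # It returns two 2D arrays; the values carry no physical meaning.
    u = ssh * jnp.cos(jnp.deg2rad(lat))
    v = ssh * jnp.sin(jnp.deg2rad(lon))
    return u, v


n_time, n_lat, n_lon = 3, 4, 5
ssh_3d = jnp.ones((n_time, n_lat, n_lon))
lat_2d = jnp.zeros((n_lat, n_lon))
lon_2d = jnp.zeros((n_lat, n_lon))

# Map over axis 0 of the SSH; broadcast the grids unchanged
vmap_toy = jax.vmap(toy_velocity, in_axes=(0, None, None))
u_3d, v_3d = vmap_toy(ssh_3d, lat_2d, lon_2d)

print(u_3d.shape, v_3d.shape)
```

Each output gains a leading time axis, while the 2D grids are never duplicated.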

By default, the cyclogeostrophy function relies on our minimization-based method. Its method argument makes it possible to use the fixed-point method instead, as described by Penven et al. (2014). Additional arguments give finer control over the hyperparameters of each approach.
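
The difference between the two strategies can be illustrated on a scalar toy problem. For a circular cyclonic eddy of radius R, the balance reduces to v + v² / (f R) = v_g. The sketch below solves it once by fixed-point iteration and once by plain gradient descent on the squared residual. All numerical values, the learning rate and the iteration counts are hypothetical, and plain gradient descent is only a stand-in for whatever optimizer the package uses.

```python
import jax

F_TIMES_R = 5.0  # hypothetical f * R in m/s (f = 1e-4 s^-1, R = 50 km)
V_GEOS = 1.0     # hypothetical geostrophic speed in m/s


def residual(v):
    # Scalar analogue of the cyclogeostrophic imbalance
    return v + v ** 2 / F_TIMES_R - V_GEOS


def loss(v):
    # Squared imbalance, the quantity the minimization drives to zero
    return residual(v) ** 2


# Fixed-point iteration: v <- v_g - v^2 / (f R), starting from v_g
v_fp = V_GEOS
for _ in range(50):
    v_fp = V_GEOS - v_fp ** 2 / F_TIMES_R

# Gradient descent on the squared imbalance, starting from v_g
grad_loss = jax.grad(loss)
v_min = V_GEOS
for _ in range(100):
    v_min = v_min - 0.1 * grad_loss(v_min)

print(float(v_fp), float(v_min))
```

Both iterations approach the positive root of v² + 5 v - 5 = 0, that is (-5 + √45) / 2 ≈ 0.854 m/s, which is below the geostrophic speed as expected for a cyclonic eddy.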

See the jaxparrow documentation for more details (including the API description and step-by-step examples).

Contributing

Contributions are welcome! See CONTRIBUTING.md and CONDUCT.md to get started.

How to cite

If you use this software, please cite it: CITATION.cff. Thank you!
