Static Analysis of Automatic Differentiation via Dual-Intervals on JAX jaxpr
This project implements dual-interval arithmetic to statically analyze programs expressed in jaxpr (JAX expression), JAX's intermediate representation. The analysis computes sound over-approximations of both the values a program can produce and its gradients.
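As a rough illustration of the idea, the sketch below pairs an interval bounding a value with an interval bounding its derivative, and propagates both through addition and multiplication (using the product rule for the derivative part). It is a minimal plain-Python sketch that assumes a single scalar input variable; the names Interval, DualInterval, variable and constant are hypothetical and are not this project's API, which works on jaxpr equations rather than Python operators.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    """Closed interval [lo, hi] (hypothetical helper, not the project's API)."""
    lo: float
    hi: float

    def __add__(self, other: "Interval") -> "Interval":
        return Interval(self.lo + other.lo, self.hi + other.hi)

    def __mul__(self, other: "Interval") -> "Interval":
        # The product's range is bounded by the four endpoint products.
        products = (self.lo * other.lo, self.lo * other.hi,
                    self.hi * other.lo, self.hi * other.hi)
        return Interval(min(products), max(products))


@dataclass(frozen=True)
class DualInterval:
    """Bounds on a value and on its derivative w.r.t. one input (assumed scalar)."""
    val: Interval
    grad: Interval

    def __add__(self, other: "DualInterval") -> "DualInterval":
        # Sum rule: (u + v)' = u' + v'
        return DualInterval(self.val + other.val, self.grad + other.grad)

    def __mul__(self, other: "DualInterval") -> "DualInterval":
        # Product rule: (u * v)' = u' * v + u * v', evaluated with intervals.
        return DualInterval(self.val * other.val,
                            self.grad * other.val + self.val * other.grad)


def variable(lo: float, hi: float) -> DualInterval:
    """The input itself: value in [lo, hi], derivative exactly 1."""
    return DualInterval(Interval(lo, hi), Interval(1.0, 1.0))


def constant(c: float) -> DualInterval:
    """A constant: value exactly c, derivative exactly 0."""
    return DualInterval(Interval(c, c), Interval(0.0, 0.0))


# Usage: bound f(x) = x*x + 3*x and f'(x) over x in [1, 2].
x = variable(1.0, 2.0)
y = x * x + constant(3.0) * x
print(y.val, y.grad)  # Interval(lo=4.0, hi=10.0) Interval(lo=5.0, hi=7.0)
```

Here the bounds happen to be exact (f is monotone on [1, 2] and f'(x) = 2x + 3), but in general interval arithmetic only guarantees enclosures, which may be loose when a variable appears more than once in an expression.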
To reproduce our visualization plots, run the following commands:
python visualize_bounds.py
python visualize_gpt2_bounds.py