A simple library that provides performant implementations of standard econometrics routines in the JAX ecosystem.
- `jax` arrays everywhere
- `lineax` for solving linear systems
- `jaxopt` and `optax` for numerical optimization (Levenberg–Marquardt for nonlinear least-squares problems and SGD for larger problems)
- Linear Regression with multiple solver backends (lineax, JAX, numpy)
- Fixed Effects Regression with JAX-accelerated alternating projections
- GMM and IV Estimation
- Causal Inference (IPW, AIPW, Entropy Balancing)
- Maximum Likelihood Estimation (Logistic, Poisson)
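The alternating-projections idea behind the fixed effects estimator can be sketched in plain JAX: sweep over each FE dimension, subtracting group means, until the data are (approximately) purged of all fixed effects, then run OLS on the demeaned arrays (Frisch–Waugh–Lovell). This is a minimal sketch, not the library's implementation. It assumes integer group ids starting at 0 and uses a fixed number of sweeps instead of a convergence tolerance; `group_demean`, `demean`, and `fe_ols` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration; not part of the jaxonometrics API.

def group_demean(v, ids, n_groups):
    # One projection: subtract the within-group mean of every column of v.
    sums = jax.ops.segment_sum(v, ids, num_segments=n_groups)
    counts = jax.ops.segment_sum(jnp.ones(v.shape[0]), ids, num_segments=n_groups)
    return v - (sums / jnp.maximum(counts, 1.0)[:, None])[ids]

def demean(v, fe_ids, n_iter=100):
    # Alternate projections over all FE dimensions for a fixed number of sweeps
    # (assumption: no early stopping on a tolerance).
    sizes = [int(ids.max()) + 1 for ids in fe_ids]
    for _ in range(n_iter):
        for ids, g in zip(fe_ids, sizes):
            v = group_demean(v, ids, g)
    return v

def fe_ols(X, y, fe_ids, n_iter=100):
    # Frisch-Waugh-Lovell: OLS on the jointly demeaned y and X.
    Z = demean(jnp.column_stack([y, X]), fe_ids, n_iter)
    coef, *_ = jnp.linalg.lstsq(Z[:, 1:], Z[:, 0])
    return coef

# Simulated two-way panel: 20 firms, 5 years.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
n = 500
firm = jax.random.randint(k1, (n,), 0, 20)
year = jax.random.randint(k2, (n,), 0, 5)
X = jax.random.normal(k3, (n, 2)) + 0.1 * firm[:, None]  # regressors correlated with the firm effect
y = X @ jnp.array([1.5, -2.0]) + 0.5 * firm + 0.3 * year + 0.1 * jax.random.normal(k4, (n,))
coef = fe_ols(X, y, [firm, year])
```

The slope estimates should match a regression that includes explicit firm and year dummy columns, which is exactly the design matrix alternating projections avoids building.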
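For the IV side, the textbook two-stage least squares (2SLS) estimator, which is linear GMM with weight matrix $(Z^\top Z)^{-1}$, fits in a few lines of JAX. This is a generic sketch rather than the library's estimator; `tsls` is a hypothetical name, and `X` and `Z` are assumed to already include a constant column.

```python
import jax
import jax.numpy as jnp

def tsls(y, X, Z):
    # Hypothetical helper, not the jaxonometrics API.
    # First stage: project the regressors onto the instrument space, X_hat = P_Z X.
    gamma, *_ = jnp.linalg.lstsq(Z, X)
    X_hat = Z @ gamma
    # Second stage: regressing y on X_hat gives (X' P_Z X)^{-1} X' P_Z y.
    beta, *_ = jnp.linalg.lstsq(X_hat, y)
    return beta

# Simulated design: x is endogenous (correlated with the error u), z is a valid instrument.
kz, ku, kv = jax.random.split(jax.random.PRNGKey(1), 3)
n = 5000
z = jax.random.normal(kz, (n,))
u = jax.random.normal(ku, (n,))
x = 0.8 * z + 0.5 * u + jax.random.normal(kv, (n,))
y = 1.0 + 2.0 * x + u

ones = jnp.ones(n)
X = jnp.column_stack([ones, x])
Z = jnp.column_stack([ones, z])
beta_iv = tsls(y, X, Z)
beta_ols, *_ = jnp.linalg.lstsq(X, y)
```

Because `x` moves with `u`, OLS overstates the slope, while 2SLS recovers it by using only the variation in `x` driven by `z`.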
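The IPW and AIPW estimators of the average treatment effect can be illustrated with a single discrete confounder, where both the propensity score and the outcome models reduce to cell means. This is a textbook sketch under that simplifying assumption, not how jaxonometrics fits its nuisance models; `cell_mean` and `ipw_aipw` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

def cell_mean(v, cells, n_cells, weights=None):
    # Hypothetical helper: (weighted) mean of v within each covariate cell,
    # broadcast back to the observations.
    w = jnp.ones_like(v) if weights is None else weights
    num = jax.ops.segment_sum(v * w, cells, num_segments=n_cells)
    den = jax.ops.segment_sum(w, cells, num_segments=n_cells)
    return (num / den)[cells]

def ipw_aipw(y, d, cells, n_cells):
    e = cell_mean(d, cells, n_cells)                   # propensity e(x) = P(D=1 | x)
    mu1 = cell_mean(y, cells, n_cells, weights=d)      # E[Y | D=1, x]
    mu0 = cell_mean(y, cells, n_cells, weights=1 - d)  # E[Y | D=0, x]
    ipw = jnp.mean(d * y / e - (1 - d) * y / (1 - e))
    aipw = jnp.mean(mu1 - mu0 + d * (y - mu1) / e - (1 - d) * (y - mu0) / (1 - e))
    return ipw, aipw

# Simulated data: treatment is more likely at high x, and x also raises y.
kx, kd, ky = jax.random.split(jax.random.PRNGKey(3), 3)
n = 20000
x = jax.random.randint(kx, (n,), 0, 3)
d = jax.random.bernoulli(kd, jnp.array([0.2, 0.5, 0.8])[x]).astype(jnp.float32)
y = 2.0 * d + 1.5 * x + jax.random.normal(ky, (n,))  # true ATE = 2

naive = y[d == 1].mean() - y[d == 0].mean()  # confounded difference in means
ipw, aipw = ipw_aipw(y, d, x, 3)
```

With a saturated cell-mean model, IPW and AIPW both reduce to the stratified difference in means. With parametric nuisance models (for example a logit propensity score) they generally differ, and AIPW remains consistent if either the propensity or the outcome model is correctly specified.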
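Maximum likelihood for the logistic model can be sketched with `jax.grad` and `jax.hessian` driving Newton steps on the negative log-likelihood. The library lists `jaxopt` and `optax` as its optimizers; this sketch swaps in a hand-rolled Newton iteration to stay dependency-free, and `logit_nll` and `fit_logit` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def logit_nll(beta, X, y):
    # Negative log-likelihood: sum of log(1 + exp(eta)) - y * eta, computed stably.
    eta = X @ beta
    return jnp.sum(jnp.logaddexp(0.0, eta) - y * eta)

grad_fn = jax.jit(jax.grad(logit_nll))
hess_fn = jax.jit(jax.hessian(logit_nll))

def fit_logit(X, y, n_iter=25):
    # Hypothetical helper: plain Newton-Raphson from zero, fixed number of steps.
    beta = jnp.zeros(X.shape[1])
    for _ in range(n_iter):
        beta = beta - jnp.linalg.solve(hess_fn(beta, X, y), grad_fn(beta, X, y))
    return beta

# Simulated binary outcomes with an intercept and one covariate.
kx, ky = jax.random.split(jax.random.PRNGKey(2))
n = 5000
X = jnp.column_stack([jnp.ones(n), jax.random.normal(kx, (n,))])
true_beta = jnp.array([-0.5, 1.0])
y = jax.random.bernoulli(ky, jax.nn.sigmoid(X @ true_beta)).astype(jnp.float32)
beta_hat = fit_logit(X, y)
```

The same pattern extends to the Poisson model by replacing the log-likelihood with `jnp.sum(jnp.exp(eta) - y * eta)`.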
jaxonometrics supports high-performance fixed effects regression with multiple FE variables:
```python
from jaxonometrics import LinearRegression
import jax.numpy as jnp

# Your data (`data`, `target`, `firm_identifiers`, `year_identifiers` are placeholders)
X = jnp.asarray(data)  # (n_obs, n_features)
y = jnp.asarray(target)  # (n_obs,)
firm_ids = jnp.asarray(firm_identifiers, dtype=jnp.int32)
year_ids = jnp.asarray(year_identifiers, dtype=jnp.int32)

# Two-way fixed effects
model = LinearRegression(solver="lineax")
model.fit(X, y, fe=[firm_ids, year_ids])
coefficients = model.params["coef"]
```

Install from GitHub with `uv`:

```bash
uv pip install git+https://github.com/py-econometrics/jaxonometrics
```

or clone the repository and install in editable mode.
Run the full test suite:

```bash
pytest tests/ -v
```

Run only fixed effects tests:

```bash
pytest tests/ -m fe -v
```

Run tests excluding slow ones:

```bash
pytest tests/ -m "not slow" -v
```