Skip to content

Commit 6f568c5

Browse files
committed
add standard errors implementation, clean up notebook
1 parent bbbb0fc commit 6f568c5

File tree

3 files changed

+391
-211
lines changed

3 files changed

+391
-211
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
*/__pycache__/*
1+
*/**/__pycache__/*
22
*.csv
33
*.dta
44
*.xlsx

jaxonometrics/linear.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Dict, Optional
22

3+
import numpy as np
4+
35
import jax.numpy as jnp
46
import lineax as lx
57

@@ -11,27 +13,84 @@ class LinearRegression(BaseEstimator):
1113
Linear regression model using lineax for efficient solving.
1214
1315
This class provides a simple interface for fitting a linear regression
14-
model, especially useful for high-dimensional problems where n > p.
16+
model, especially useful for high-dimensional problems where p > n.
1517
"""
1618

17-
def __init__(self):
19+
def __init__(self, solver: str = "lineax"):
    """Initialize the LinearRegression model.

    Args:
        solver: Backend used by ``fit`` to solve the least-squares problem.
            One of "lineax" (default), "jax" or "numpy".

    Raises:
        ValueError: If ``solver`` is not one of the supported names.
    """
    # fit() only implements these three backends, so reject anything else
    # up front instead of letting a typo surface later.
    if solver not in ("lineax", "jax", "numpy"):
        raise ValueError(
            f"Unknown solver {solver!r}; expected 'lineax', 'jax' or 'numpy'."
        )
    super().__init__()
    self.solver: str = solver
1927

20-
def fit(self, X: jnp.ndarray, y: jnp.ndarray) -> "LinearRegression":
28+
def fit(
    self,
    X: jnp.ndarray,
    y: jnp.ndarray,
    se: Optional[str] = None,
) -> "LinearRegression":
    """
    Fit the linear model by least squares.

    Args:
        X: The design matrix of shape (n_samples, n_features).
        y: The target vector of shape (n_samples,).
        se: Which standard errors to compute: "HC1" for robust standard
            errors, "classical" for classical SEs. If None (the default)
            or otherwise falsy, no standard errors are computed.

    Returns:
        The fitted estimator. Coefficients are stored in
        ``self.params["coef"]``.

    Raises:
        ValueError: If ``self.solver`` is not "lineax", "jax" or "numpy".
    """
    if self.solver == "lineax":
        # Per the lineax docs, passing well_posed=None is remarkably general:
        #   - non-square operator: lineax.QR (the most likely case here)
        #   - diagonal operator: lineax.Diagonal
        #   - tridiagonal operator: lineax.Tridiagonal
        #   - triangular operator: lineax.Triangular
        #   - positive or negative (semi-)definite matrix: lineax.Cholesky
        sol = lx.linear_solve(
            operator=lx.MatrixLinearOperator(X),
            vector=y,
            solver=lx.AutoLinearSolver(well_posed=None),
        )
        coef = sol.value
    elif self.solver == "jax":
        coef = jnp.linalg.lstsq(X, y)[0]
    elif self.solver == "numpy":  # for completeness
        # X and y are rebound to NumPy arrays; these converted arrays are
        # also what the standard-error step below receives.
        X, y = np.array(X), np.array(y)
        coef = jnp.array(np.linalg.lstsq(X, y, rcond=None)[0])
    else:
        # Previously an unrecognised solver fell through every branch and
        # returned an estimator without coefficients.
        raise ValueError(
            f"Unknown solver {self.solver!r}; "
            "expected 'lineax', 'jax' or 'numpy'."
        )
    self.params = {"coef": coef}

    if se:
        # _vcov stores the standard errors in self.params.
        self._vcov(y=y, X=X, se=se)
    return self
75+
76+
def predict(self, X: jnp.ndarray) -> jnp.ndarray:
    """Return predictions for ``X`` using the fitted coefficients.

    Args:
        X: Design matrix whose columns line up with the coefficients in
            ``self.params["coef"]``. Array-likes (lists, NumPy arrays) are
            converted to a JAX array first.

    Returns:
        The product of ``X`` and the coefficient vector.
    """
    # asarray converts array-likes and passes existing JAX arrays through,
    # replacing the manual isinstance check.
    X = jnp.asarray(X)
    return jnp.dot(X, self.params["coef"])
80+
81+
def _vcov(
82+
self,
83+
y: jnp.ndarray,
84+
X: jnp.ndarray,
85+
se: str = "HC1",
86+
) -> None:
87+
n, k = X.shape
88+
ε = y - X @ self.params["coef"]
89+
if se == "HC1":
90+
M = jnp.einsum("ij,i,ik->jk", X, ε**2, X) # yer a wizard harry
91+
XtX = jnp.linalg.inv(X.T @ X)
92+
Σ = XtX @ M @ XtX
93+
self.params["se"] = jnp.sqrt((n / (n - k)) * jnp.diag(Σ))
94+
elif se == "classical":
95+
XtX_inv = jnp.linalg.inv(X.T @ X)
96+
self.params["se"] = jnp.sqrt(jnp.diag(XtX_inv) * jnp.var(ε, ddof=k))

0 commit comments

Comments
 (0)