11from typing import Dict , Optional
22
3+ import numpy as np
4+
35import jax .numpy as jnp
46import lineax as lx
57
@@ -11,27 +13,84 @@ class LinearRegression(BaseEstimator):
1113 Linear regression model using lineax for efficient solving.
1214
1315 This class provides a simple interface for fitting a linear regression
14- model, especially useful for high-dimensional problems where n > p .
16+ model, especially useful for high-dimensional problems where p > n .
1517 """
1618
def __init__(self, solver="lineax"):
    """Create an unfitted linear regression estimator.

    Args:
        solver (str, optional): Backend that ``fit`` uses to solve the
            least-squares problem. One of "lineax" (the default), "jax",
            or "numpy".
    """
    super().__init__()
    # Only stored here; the value is consulted when fit() runs.
    self.solver = solver
1927
def fit(
    self,
    X: jnp.ndarray,
    y: jnp.ndarray,
    se: Optional[str] = None,
) -> "LinearRegression":
    """
    Fit the linear model by least squares.

    Args:
        X: The design matrix of shape (n_samples, n_features).
        y: The target vector of shape (n_samples,).
        se: Which standard errors to compute, if any. "HC1" for robust
            standard errors, "classical" for classical SEs. When None
            (or otherwise falsy) no standard errors are computed.

    Returns:
        The fitted estimator, with the coefficients stored in
        ``self.params["coef"]`` (and standard errors in
        ``self.params["se"]`` when ``se`` is given).

    Raises:
        ValueError: If ``self.solver`` is not "lineax", "jax" or "numpy".
    """
    if self.solver == "lineax":
        # Per the lineax docs, well_posed=None lets AutoLinearSolver pick
        # a solver from the operator's structure: QR for non-square
        # operators (the most likely case here), and Diagonal,
        # Tridiagonal, Triangular or Cholesky for operators with the
        # corresponding structure.
        sol = lx.linear_solve(
            operator=lx.MatrixLinearOperator(X),
            vector=y,
            solver=lx.AutoLinearSolver(well_posed=None),
        )
        coef = sol.value
    elif self.solver == "jax":
        coef = jnp.linalg.lstsq(X, y)[0]
    elif self.solver == "numpy":  # for completeness
        # Use local numpy views instead of rebinding X / y, so the
        # standard-error step below sees the same inputs for every solver.
        np_solution = np.linalg.lstsq(np.asarray(X), np.asarray(y), rcond=None)
        coef = jnp.asarray(np_solution[0])
    else:
        # Previously an unknown solver fell through silently and left the
        # estimator without coefficients.
        raise ValueError(
            f"Unknown solver {self.solver!r}; expected 'lineax', 'jax' or 'numpy'."
        )

    self.params = {"coef": coef}

    if se:
        # Adds the standard errors to self.params.
        self._vcov(y=jnp.asarray(y), X=jnp.asarray(X), se=se)
    return self
75+
def predict(self, X: jnp.ndarray) -> jnp.ndarray:
    """Return predictions for ``X`` using the fitted coefficients.

    Args:
        X: Design matrix; anything that is not already a jax array is
            converted with ``jnp.array`` first.

    Returns:
        The product of ``X`` with ``self.params["coef"]``.
    """
    design = X if isinstance(X, jnp.ndarray) else jnp.array(X)
    coef = self.params["coef"]
    return jnp.dot(design, coef)
80+
81+ def _vcov (
82+ self ,
83+ y : jnp .ndarray ,
84+ X : jnp .ndarray ,
85+ se : str = "HC1" ,
86+ ) -> None :
87+ n , k = X .shape
88+ ε = y - X @ self .params ["coef" ]
89+ if se == "HC1" :
90+ M = jnp .einsum ("ij,i,ik->jk" , X , ε ** 2 , X ) # yer a wizard harry
91+ XtX = jnp .linalg .inv (X .T @ X )
92+ Σ = XtX @ M @ XtX
93+ self .params ["se" ] = jnp .sqrt ((n / (n - k )) * jnp .diag (Σ ))
94+ elif se == "classical" :
95+ XtX_inv = jnp .linalg .inv (X .T @ X )
96+ self .params ["se" ] = jnp .sqrt (jnp .diag (XtX_inv ) * jnp .var (ε , ddof = k ))
0 commit comments