 import jax.numpy as jnp
 import optax

+
 from .base import BaseEstimator


@@ -13,18 +14,28 @@ class MaximumLikelihoodEstimator(BaseEstimator):
     Base class for Maximum Likelihood Estimators using Optax.
     """

-    def __init__(self, optimizer: Optional[optax.GradientTransformation] = None, maxiter: int = 5000, tol: float = 1e-4):
+    def __init__(
+        self,
+        optimizer: Optional[optax.GradientTransformation] = None,
+        maxiter: int = 5000,
+        tol: float = 1e-4,
+    ):
         super().__init__()
-        self.optimizer = optimizer if optimizer is not None else optax.adam(learning_rate=1e-3)
+        self.optimizer = optimizer if optimizer is not None else optax.lbfgs()
         self.maxiter = maxiter
         # Tol is not directly used by basic optax loops for stopping but can be a reference
         # or used if a convergence check is manually added.
         self.tol = tol
-        self.params: Dict[str, jnp.ndarray] = {} # Initialize params
-        self.history: Dict[str, list] = {"loss": []} # To store loss history
+        self.params: Dict[str, jnp.ndarray] = {}  # Initialize params
+        self.history: Dict[str, list] = {"loss": []}  # To store loss history

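Note on the new default: any optax.GradientTransformation can still be injected at construction time, but the fit loop below leans on the L-BFGS optimizer state (via optax.value_and_grad_from_state). A minimal construction sketch, assuming the LogisticRegression subclass defined later in this file; the memory_size argument to optax.lbfgs is shown purely as an illustration of tuning the default solver:

import optax

model = LogisticRegression()  # default solver: optax.lbfgs()
model_tuned = LogisticRegression(
    optimizer=optax.lbfgs(memory_size=20),  # keep a longer curvature history
    maxiter=1000,
    tol=1e-6,
)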
     @abstractmethod
-    def _negative_log_likelihood(self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray) -> float:
+    def _negative_log_likelihood(
+        self,
+        params: jnp.ndarray,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> float:
         """
         Computes the negative log-likelihood for the model.
         Must be implemented by subclasses.
@@ -37,7 +48,13 @@ def _negative_log_likelihood(self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray) -> float:
         """
         raise NotImplementedError

-    def fit(self, X: jnp.ndarray, y: jnp.ndarray, init_params: Optional[jnp.ndarray] = None) -> "MaximumLikelihoodEstimator":
+    def fit(
+        self,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+        init_params: Optional[jnp.ndarray] = None,
+        verbose: bool = False,
+    ) -> "MaximumLikelihoodEstimator":
         """
         Fit the model using the specified Optax optimizer.

@@ -53,10 +70,10 @@ def fit(self, X: jnp.ndarray, y: jnp.ndarray, init_params: Optional[jnp.ndarray] = None) -> "MaximumLikelihoodEstimator":
         """
         n_features = X.shape[1]
         if init_params is None:
-            try: # Try to use a key for initialization for better starting points
-                key = jax.random.PRNGKey(0) # Simple fixed key for reproducibility
+            try:  # Try to use a key for initialization for better starting points
+                key = jax.random.PRNGKey(0)  # Simple fixed key for reproducibility
                 init_params_val = jax.random.normal(key, (n_features,)) * 0.01
-            except Exception: # Fallback if key generation fails or not in context
+            except Exception:  # Fallback if key generation fails or not in context
                 init_params_val = jnp.zeros(n_features)
         else:
             init_params_val = init_params
@@ -67,31 +84,41 @@ def loss_fn(params_lg):
             return self._negative_log_likelihood(params_lg, X, y)

         # Get the gradient function
-        value_and_grad_fn = jax.value_and_grad(loss_fn)
+        value_and_grad_fn = optax.value_and_grad_from_state(loss_fn)

         # Initialize optimizer state
         opt_state = self.optimizer.init(init_params_val)

         current_params = init_params_val
-        self.history["loss"] = [] # Reset loss history
+        self.history["loss"] = []  # Reset loss history

         # Optimization loop
         for i in range(self.maxiter):
-            loss_val, grads = value_and_grad_fn(current_params)
-            updates, opt_state = self.optimizer.update(grads, opt_state, current_params)
-            current_params = optax.apply_updates(current_params, updates)
+            loss_val, grads = value_and_grad_fn(current_params, state=opt_state)
+            updates, opt_state = self.optimizer.update(
+                grads,
+                opt_state,
+                current_params,
+                value=loss_val,
+                grad=grads,
+                value_fn=loss_fn,
+            )
+            current_params = optax.apply_updates(
+                current_params,
+                updates,
+            )
             self.history["loss"].append(loss_val)
-
-            # Basic convergence check (optional, and might need adjustment)
-            # This is a simple check on loss improvement. More robust checks might look at gradient norms or param changes.
             if i > 10 and self.tol > 0:
-                loss_change = abs(self.history["loss"][-2] - self.history["loss"][-1]) / (abs(self.history["loss"][-2]) + 1e-8)
+                loss_change = abs(
+                    self.history["loss"][-2] - self.history["loss"][-1]
+                ) / (abs(self.history["loss"][-2]) + 1e-8)
                 if loss_change < self.tol:
-                    # print(f"Convergence tolerance {self.tol} met at iteration {i}.")
+                    if verbose:
+                        print(f"Convergence tolerance {self.tol} met at iteration {i}.")
                     break

         self.params = {"coef": current_params}
-        self.iterations_run = i + 1 # Store how many iterations actually ran
+        self.iterations_run = i + 1  # Store how many iterations actually ran

         return self

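For reference, the optimizer plumbing in this loop follows optax's documented L-BFGS recipe: optax.value_and_grad_from_state reuses the value and gradient cached in the L-BFGS state (so they are not recomputed after the line search), and update() receives value, grad, and value_fn so the line search can evaluate trial points. A minimal standalone sketch on a toy quadratic, independent of this class:

import jax.numpy as jnp
import optax

def f(x):
    return jnp.sum((x - 3.0) ** 2)  # toy objective with minimum at x = 3

solver = optax.lbfgs()
params = jnp.zeros(4)
opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(f)

for _ in range(10):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=f
    )
    params = optax.apply_updates(params, updates)

print(params)  # expected to approach [3., 3., 3., 3.]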
@@ -104,21 +131,28 @@ def summary(self) -> None:
         print(f"{self.__class__.__name__} Results")
         print("=" * 30)
         print(f"Optimizer: {self.optimizer}")
-        if hasattr(self, 'iterations_run'):
-            print(f"Optimization ran for {self.iterations_run}/{self.maxiter} iterations.")
+        if hasattr(self, "iterations_run"):
+            print(
+                f"Optimization ran for {self.iterations_run}/{self.maxiter} iterations."
+            )
         if self.history["loss"]:
             print(f"Final Loss: {self.history['loss'][-1]:.4e}")

         print(f"Coefficients: {self.params['coef']}")
         print("=" * 30)


-class Logit(MaximumLikelihoodEstimator):
+class LogisticRegression(MaximumLikelihoodEstimator):
     """
     Logistic Regression model.
     """

-    def _negative_log_likelihood(self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray) -> float:
+    def _negative_log_likelihood(
+        self,
+        params: jnp.ndarray,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> float:
         """
         Computes the negative log-likelihood for logistic regression.
         NLL = -Σ [y_i * log(p_i) + (1 - y_i) * log(1 - p_i)]
@@ -128,15 +162,10 @@ def _negative_log_likelihood(self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray) -> float:
         log(1-p_i) = log_sigmoid(-(X_i @ β))
         """
         logits = X @ params
-        # Using jax.nn.log_sigmoid for log(σ(z)) and log(1-σ(z)) = log(σ(-z))
-        log_p = jax.nn.log_sigmoid(logits)
-        log_one_minus_p = jax.nn.log_sigmoid(-logits) # log(1 - sigmoid(x)) = log(sigmoid(-x))
-
-        # Sum over samples
-        nll = -jnp.sum(y * log_p + (1 - y) * log_one_minus_p)
-        # Return mean NLL for more stable optimization across batch sizes,
-        # though sum is also common. Your example used sum. Let's stick to sum for now.
-        return nll # / X.shape[0] if averaging
+        # alt: Using jax.nn.log_sigmoid for log(σ(z)) and log(1-σ(z)) = log(σ(-z))
+        h = jax.scipy.special.expit(logits)
+        nll = -jnp.sum(y * jnp.log(h) + (1 - y) * jnp.log1p(-h))
+        return nll  # / X.shape[0] if averaging

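The expit form above and the log_sigmoid form mentioned in the comment are algebraically identical; the log_sigmoid variant is the safer choice when logits saturate, since log(expit(z)) hits log(0) for large |z| in finite precision. A quick equivalence check on small synthetic inputs (all names below are illustrative):

import jax
import jax.numpy as jnp

def nll_expit(params, X, y):
    logits = X @ params
    h = jax.scipy.special.expit(logits)
    return -jnp.sum(y * jnp.log(h) + (1 - y) * jnp.log1p(-h))

def nll_logsigmoid(params, X, y):
    logits = X @ params
    # log(p) = log_sigmoid(z), log(1 - p) = log_sigmoid(-z)
    return -jnp.sum(y * jax.nn.log_sigmoid(logits) + (1 - y) * jax.nn.log_sigmoid(-logits))

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (8, 3))
params = jnp.array([0.5, -1.0, 2.0])
y = (X @ params > 0).astype(jnp.float32)
print(nll_expit(params, X, y), nll_logsigmoid(params, X, y))  # agree to float precision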
     def predict_proba(self, X: jnp.ndarray) -> jnp.ndarray:
         """
@@ -150,7 +179,7 @@ def predict_proba(self, X: jnp.ndarray) -> jnp.ndarray:
             raise ValueError("Model has not been fitted yet.")

         logits = X @ self.params["coef"]
-        return jax.nn.sigmoid(logits) # jax.scipy.special.expit is equivalent
+        return jax.nn.sigmoid(logits)  # jax.scipy.special.expit is equivalent

     def predict(self, X: jnp.ndarray, threshold: float = 0.5) -> jnp.ndarray:
         """
@@ -170,18 +199,23 @@ class PoissonRegression(MaximumLikelihoodEstimator):
     Poisson Regression model.
     """

-    def _negative_log_likelihood(self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray) -> float:
+    def _negative_log_likelihood(
+        self,
+        params: jnp.ndarray,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> float:
         """
         Computes the negative log-likelihood for Poisson regression.
         The log(y_i!) term is constant w.r.t params, so ignored for optimization.
         NLL = Σ [exp(X_i @ β) - y_i * (X_i @ β)]
         """
         linear_predictor = X @ params
-        lambda_i = jnp.exp(linear_predictor) # Predicted rates
+        lambda_i = jnp.exp(linear_predictor)  # Predicted rates

         # Sum over samples
         nll = jnp.sum(lambda_i - y * linear_predictor)
-        return nll # / X.shape[0] if averaging
+        return nll  # / X.shape[0] if averaging

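Dropping log(y_i!) leaves the minimizer unchanged because that term does not depend on β; differentiating what remains gives the familiar GLM score X^T(λ - y). A small check with jax.grad, on made-up inputs:

import jax
import jax.numpy as jnp

def poisson_nll(params, X, y):
    eta = X @ params
    return jnp.sum(jnp.exp(eta) - y * eta)  # log(y!) omitted: constant in params

key = jax.random.PRNGKey(1)
X = jax.random.normal(key, (10, 2))
params = jnp.array([0.3, -0.2])
y = jnp.array([0.0, 1.0, 2.0, 0.0, 3.0, 1.0, 0.0, 2.0, 1.0, 0.0])

auto_grad = jax.grad(poisson_nll)(params, X, y)
score = X.T @ (jnp.exp(X @ params) - y)  # closed-form gradient
print(jnp.allclose(auto_grad, score, atol=1e-5))  # expected: True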
     def predict(self, X: jnp.ndarray) -> jnp.ndarray:
         """
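Putting the pieces together, a rough end-to-end usage sketch on synthetic data (the import path is hypothetical; adjust it to wherever this module lives in the package):

import jax
import jax.numpy as jnp

from mypackage.estimators import LogisticRegression  # hypothetical path

key_x, key_u = jax.random.split(jax.random.PRNGKey(42))
X = jax.random.normal(key_x, (200, 3))
true_beta = jnp.array([1.5, -2.0, 0.5])
y = (jax.random.uniform(key_u, (200,)) < jax.nn.sigmoid(X @ true_beta)).astype(jnp.float32)

model = LogisticRegression(maxiter=200, tol=1e-6)
model.fit(X, y, verbose=True)
model.summary()

proba = model.predict_proba(X)            # P(y = 1 | X)
labels = model.predict(X, threshold=0.5)  # hard 0/1 labels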