
Commit 04e5a90
"make name consistent"
1 parent 4afa9fb

9 files changed: +135 -65 lines


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -27,3 +27,4 @@
 *.blg
 *.out
 *.synctex.gz
+nb/scratch.ipynb

jaxonometrics/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from .causal import EntropyBalancing, IPW, AIPW # Added IPW, AIPW
 from .gmm import GMM, LinearIVGMM, TwoStepGMM
 from .linear import LinearRegression
-from .mle import Logit, PoissonRegression, MaximumLikelihoodEstimator # Added MLE models
+from .mle import LogisticRegression, PoissonRegression, MaximumLikelihoodEstimator # Added MLE models

 __all__ = [
     "BaseEstimator",
@@ -20,6 +20,6 @@
     "TwoStepGMM",
     "LinearRegression",
     "MaximumLikelihoodEstimator",
-    "Logit",
+    "LogisticRegression",
     "PoissonRegression",
 ]
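With the rename, downstream code imports the estimator under its new name from the package root. A minimal usage sketch, assuming the package is installed; the toy data below is illustrative and not from the repository:

import jax.numpy as jnp
from jaxonometrics import LogisticRegression  # previously exported as Logit

# Tiny synthetic design matrix (intercept column included) and binary outcome.
X = jnp.array([[1.0, 0.2], [1.0, -0.4], [1.0, 1.3], [1.0, -0.9]])
y = jnp.array([1.0, 0.0, 1.0, 0.0])

model = LogisticRegression(maxiter=1000)
model.fit(X, y)
print(model.predict_proba(X))  # fitted probabilities
model.summary()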

jaxonometrics/causal.py

Lines changed: 4 additions & 4 deletions
@@ -67,8 +67,8 @@ def __init__(self, propensity_optimizer: Optional[optax.GradientTransformation]
             ps_clip_epsilon: Small constant to clip propensity scores.
         """
         super().__init__()
-        from .mle import Logit # Local import
-        self.logit_model = Logit(optimizer=propensity_optimizer, maxiter=propensity_maxiter)
+        from .mle import LogisticRegression
+        self.logit_model = LogisticRegression(optimizer=propensity_optimizer, maxiter=propensity_maxiter)
         self.ps_clip_epsilon = ps_clip_epsilon
         self.params: Dict[str, Any] = {"ate": None, "propensity_scores": None}

@@ -168,11 +168,11 @@ def __init__(self,
             ps_clip_epsilon: Small constant to clip propensity scores to avoid extreme values.
         """
         super().__init__()
-        from .mle import Logit # Local import for Logit
+        from .mle import LogisticRegression

         self.outcome_model_template = outcome_model if outcome_model else LinearRegression()
         # We need two instances of the outcome model, one for T=1 and one for T=0
-        self.propensity_model = propensity_model if propensity_model else Logit()
+        self.propensity_model = propensity_model if propensity_model else LogisticRegression()

         self.ps_clip_epsilon = ps_clip_epsilon
         self.params: Dict[str, Any] = {"ate": None, "propensity_scores": None,

jaxonometrics/linear.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 from typing import Dict, Optional
+from functools import partial

 import numpy as np
 import jax # Ensure jax is imported
@@ -9,7 +10,7 @@


 # Helper function for JIT compilation of vcov calculations
-@jax.jit(static_argnames=['se_type', 'n', 'k']) # Mark se_type, n, and k as static
+@partial(jax.jit, static_argnames=['se_type', 'n', 'k']) # Mark se_type, n, and k as static
 def _calculate_vcov_details(
     coef: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray, se_type: str, n: int, k: int
 ):
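The removed decorator fails because jax.jit expects the function to compile as its first positional argument, so calling it with only keyword arguments is not a valid decorator factory; functools.partial pre-binds static_argnames and returns something that can decorate the function. A minimal sketch of the pattern, using an illustrative function rather than the repository's vcov helper:

from functools import partial

import jax
import jax.numpy as jnp


# partial pre-binds static_argnames; the static argument is baked into the
# compiled function, and a new value triggers recompilation rather than tracing.
@partial(jax.jit, static_argnames=["se_type"])
def scaled_sum(x: jnp.ndarray, se_type: str) -> jnp.ndarray:
    return jnp.sum(x * 2.0) if se_type == "HC1" else jnp.sum(x)


print(scaled_sum(jnp.ones(3), se_type="HC1"))  # 6.0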

jaxonometrics/mle.py

Lines changed: 71 additions & 37 deletions
@@ -5,6 +5,7 @@
 import jax.numpy as jnp
 import optax

+
 from .base import BaseEstimator


@@ -13,18 +14,28 @@ class MaximumLikelihoodEstimator(BaseEstimator):
     Base class for Maximum Likelihood Estimators using Optax.
     """

-    def __init__(self, optimizer: Optional[optax.GradientTransformation] = None, maxiter: int = 5000, tol: float = 1e-4):
+    def __init__(
+        self,
+        optimizer: Optional[optax.GradientTransformation] = None,
+        maxiter: int = 5000,
+        tol: float = 1e-4,
+    ):
         super().__init__()
-        self.optimizer = optimizer if optimizer is not None else optax.adam(learning_rate=1e-3)
+        self.optimizer = optimizer if optimizer is not None else optax.lbfgs()
         self.maxiter = maxiter
         # Tol is not directly used by basic optax loops for stopping but can be a reference
         # or used if a convergence check is manually added.
         self.tol = tol
-        self.params: Dict[str, jnp.ndarray] = {} # Initialize params
-        self.history: Dict[str, list] = {"loss": []} # To store loss history
+        self.params: Dict[str, jnp.ndarray] = {}  # Initialize params
+        self.history: Dict[str, list] = {"loss": []}  # To store loss history

     @abstractmethod
-    def _negative_log_likelihood(self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray) -> float:
+    def _negative_log_likelihood(
+        self,
+        params: jnp.ndarray,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> float:
         """
         Computes the negative log-likelihood for the model.
         Must be implemented by subclasses.
@@ -37,7 +48,13 @@ def _negative_log_likelihood(self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.n
         """
         raise NotImplementedError

-    def fit(self, X: jnp.ndarray, y: jnp.ndarray, init_params: Optional[jnp.ndarray] = None) -> "MaximumLikelihoodEstimator":
+    def fit(
+        self,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+        init_params: Optional[jnp.ndarray] = None,
+        verbose: bool = False,
+    ) -> "MaximumLikelihoodEstimator":
         """
         Fit the model using the specified Optax optimizer.

@@ -53,10 +70,10 @@ def fit(self, X: jnp.ndarray, y: jnp.ndarray, init_params: Optional[jnp.ndarray]
         """
         n_features = X.shape[1]
         if init_params is None:
-            try: # Try to use a key for initialization for better starting points
-                key = jax.random.PRNGKey(0) # Simple fixed key for reproducibility
+            try:  # Try to use a key for initialization for better starting points
+                key = jax.random.PRNGKey(0)  # Simple fixed key for reproducibility
                 init_params_val = jax.random.normal(key, (n_features,)) * 0.01
-            except Exception: # Fallback if key generation fails or not in context
+            except Exception:  # Fallback if key generation fails or not in context
                 init_params_val = jnp.zeros(n_features)
         else:
             init_params_val = init_params
@@ -67,31 +84,41 @@ def loss_fn(params_lg):
             return self._negative_log_likelihood(params_lg, X, y)

         # Get the gradient function
-        value_and_grad_fn = jax.value_and_grad(loss_fn)
+        value_and_grad_fn = optax.value_and_grad_from_state(loss_fn)

         # Initialize optimizer state
         opt_state = self.optimizer.init(init_params_val)

         current_params = init_params_val
-        self.history["loss"] = [] # Reset loss history
+        self.history["loss"] = []  # Reset loss history

         # Optimization loop
         for i in range(self.maxiter):
-            loss_val, grads = value_and_grad_fn(current_params)
-            updates, opt_state = self.optimizer.update(grads, opt_state, current_params)
-            current_params = optax.apply_updates(current_params, updates)
+            loss_val, grads = value_and_grad_fn(current_params, state=opt_state)
+            updates, opt_state = self.optimizer.update(
+                grads,
+                opt_state,
+                current_params,
+                value=loss_val,
+                grad=grads,
+                value_fn=loss_fn,
+            )
+            current_params = optax.apply_updates(
+                current_params,
+                updates,
+            )
             self.history["loss"].append(loss_val)
-
-            # Basic convergence check (optional, and might need adjustment)
-            # This is a simple check on loss improvement. More robust checks might look at gradient norms or param changes.
             if i > 10 and self.tol > 0:
-                loss_change = abs(self.history["loss"][-2] - self.history["loss"][-1]) / (abs(self.history["loss"][-2]) + 1e-8)
+                loss_change = abs(
+                    self.history["loss"][-2] - self.history["loss"][-1]
+                ) / (abs(self.history["loss"][-2]) + 1e-8)
                 if loss_change < self.tol:
-                    # print(f"Convergence tolerance {self.tol} met at iteration {i}.")
+                    if verbose:
+                        print(f"Convergence tolerance {self.tol} met at iteration {i}.")
                     break

         self.params = {"coef": current_params}
-        self.iterations_run = i + 1 # Store how many iterations actually ran
+        self.iterations_run = i + 1  # Store how many iterations actually ran

         return self

@@ -104,21 +131,28 @@ def summary(self) -> None:
         print(f"{self.__class__.__name__} Results")
         print("=" * 30)
         print(f"Optimizer: {self.optimizer}")
-        if hasattr(self, 'iterations_run'):
-            print(f"Optimization ran for {self.iterations_run}/{self.maxiter} iterations.")
+        if hasattr(self, "iterations_run"):
+            print(
+                f"Optimization ran for {self.iterations_run}/{self.maxiter} iterations."
+            )
         if self.history["loss"]:
             print(f"Final Loss: {self.history['loss'][-1]:.4e}")

         print(f"Coefficients: {self.params['coef']}")
         print("=" * 30)


-class Logit(MaximumLikelihoodEstimator):
+class LogisticRegression(MaximumLikelihoodEstimator):
     """
     Logistic Regression model.
     """

-    def _negative_log_likelihood(self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray) -> float:
+    def _negative_log_likelihood(
+        self,
+        params: jnp.ndarray,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> float:
         """
         Computes the negative log-likelihood for logistic regression.
         NLL = -Σ [y_i * log(p_i) + (1 - y_i) * log(1 - p_i)]
@@ -128,15 +162,10 @@ def _negative_log_likelihood(self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.n
         log(1-p_i) = log_sigmoid(-(X_i @ β))
         """
         logits = X @ params
-        # Using jax.nn.log_sigmoid for log(σ(z)) and log(1-σ(z)) = log(σ(-z))
-        log_p = jax.nn.log_sigmoid(logits)
-        log_one_minus_p = jax.nn.log_sigmoid(-logits) # log(1 - sigmoid(x)) = log(sigmoid(-x))
-
-        # Sum over samples
-        nll = -jnp.sum(y * log_p + (1 - y) * log_one_minus_p)
-        # Return mean NLL for more stable optimization across batch sizes,
-        # though sum is also common. Your example used sum. Let's stick to sum for now.
-        return nll # / X.shape[0] if averaging
+        # alt: Using jax.nn.log_sigmoid for log(σ(z)) and log(1-σ(z)) = log(σ(-z))
+        h = jax.scipy.special.expit(logits)
+        nll = -jnp.sum(y * jnp.log(h) + (1 - y) * jnp.log1p(-h))
+        return nll  # / X.shape[0] if averaging

     def predict_proba(self, X: jnp.ndarray) -> jnp.ndarray:
         """
@@ -150,7 +179,7 @@ def predict_proba(self, X: jnp.ndarray) -> jnp.ndarray:
             raise ValueError("Model has not been fitted yet.")

         logits = X @ self.params["coef"]
-        return jax.nn.sigmoid(logits) # jax.scipy.special.expit is equivalent
+        return jax.nn.sigmoid(logits)  # jax.scipy.special.expit is equivalent

     def predict(self, X: jnp.ndarray, threshold: float = 0.5) -> jnp.ndarray:
         """
@@ -170,18 +199,23 @@ class PoissonRegression(MaximumLikelihoodEstimator):
     Poisson Regression model.
     """

-    def _negative_log_likelihood(self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray) -> float:
+    def _negative_log_likelihood(
+        self,
+        params: jnp.ndarray,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> float:
         """
         Computes the negative log-likelihood for Poisson regression.
         The log(y_i!) term is constant w.r.t params, so ignored for optimization.
         NLL = Σ [exp(X_i @ β) - y_i * (X_i @ β)]
         """
         linear_predictor = X @ params
-        lambda_i = jnp.exp(linear_predictor) # Predicted rates
+        lambda_i = jnp.exp(linear_predictor)  # Predicted rates

         # Sum over samples
         nll = jnp.sum(lambda_i - y * linear_predictor)
-        return nll # / X.shape[0] if averaging
+        return nll  # / X.shape[0] if averaging

     def predict(self, X: jnp.ndarray) -> jnp.ndarray:
         """

nb/linmod.ipynb

Lines changed: 5 additions & 5 deletions
@@ -64,7 +64,7 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "INFO:2025-06-29 11:43:45,267:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
+     "INFO:2025-06-29 15:46:55,666:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
      "INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n"
     ]
    },
@@ -258,15 +258,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 6,
    "id": "25b0053d",
    "metadata": {},
    "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "9.2 ms ± 1.24 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)\n"
+     "8.98 ms ± 1.47 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)\n"
     ]
    }
   ],
@@ -287,15 +287,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 7,
    "id": "051697e3",
    "metadata": {},
    "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "12.3 ms ± 1.57 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)\n"
+     "10.9 ms ± 617 μs per loop (mean ± std. dev. of 5 runs, 100 loops each)\n"
     ]
    }
   ],

tests/test_causal.py

Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,7 @@

 from jaxonometrics.causal import IPW, AIPW
 from jaxonometrics.linear import LinearRegression
-from jaxonometrics.mle import Logit
+from jaxonometrics.mle import LogisticRegression

 # Function to generate synthetic data for causal inference tests
 def generate_causal_data(n_samples=1000, n_features=3, true_ate=2.0, seed=42):
@@ -87,7 +87,7 @@ def test_aipw_ate_estimation(causal_sim_data):
     # Using default LinearRegression for outcome, Logit for propensity
     aipw_estimator = AIPW(
         outcome_model=LinearRegression(solver="lineax"), # Explicitly pass an instance
-        propensity_model=Logit(maxiter=10000) # Explicitly pass an instance
+        propensity_model=LogisticRegression(maxiter=10000) # Explicitly pass an instance
     )
     # X should include intercept for LinearRegression and Logit as currently implemented
     aipw_estimator.fit(X, T, y)
@@ -111,7 +111,7 @@ def test_aipw_with_custom_models(causal_sim_data):
     X, T, y, _, _, _, true_ate = causal_sim_data

     # 1. Fit propensity score model
-    ps_model = Logit(maxiter=10000)
+    ps_model = LogisticRegression(maxiter=10000)
     ps_model.fit(X, T) # X includes intercept

     # 2. Fit outcome models
@@ -137,7 +137,7 @@ def test_aipw_with_custom_models(causal_sim_data):

     aipw_estimator = AIPW(
         outcome_model=LinearRegression(), # It will create new instances and fit
-        propensity_model=Logit(maxiter=10000) # It will create a new instance and fit
+        propensity_model=LogisticRegression(maxiter=10000) # It will create a new instance and fit
     )
     aipw_estimator.fit(X, T, y)
     estimated_ate = aipw_estimator.params["ate"]

tests/test_linear.py

Lines changed: 1 addition & 1 deletion
@@ -19,6 +19,6 @@ def test_linear_regression():
     X_with_intercept = jnp.c_[jnp.ones(X.shape[0]), X]
     jax_model = LinearRegression()
     jax_model.fit(X_with_intercept, jnp.array(y))
-    jax_coef = jax_model.params["beta"][1:]
+    jax_coef = jax_model.params["coef"][1:]

     assert np.allclose(sklearn_coef, jax_coef, atol=1e-6)
