
Commit 68022a0

committed "use better notation, leaner deps"
1 parent 04e5a90 commit 68022a0

File tree: 3 files changed, +146 -84 lines changed

jaxonometrics/causal.py

Lines changed: 99 additions & 49 deletions
@@ -1,11 +1,12 @@
 from typing import Dict, Optional, Any
 
-import jax # Ensure jax is imported
-import optax # Added for type hint
+import jax  # Ensure jax is imported
+import optax  # Added for type hint
 import jax.numpy as jnp
 from jaxopt import LBFGS
 
 from .base import BaseEstimator
+from .linear import LinearRegression  # For default outcome model in AIPW
 
 
 class EntropyBalancing(BaseEstimator):
@@ -55,8 +56,12 @@ class IPW(BaseEstimator):
     This implementation uses Logistic Regression for propensity score estimation.
     """
 
-    def __init__(self, propensity_optimizer: Optional[optax.GradientTransformation] = None,
-                 propensity_maxiter: int = 5000, ps_clip_epsilon: float = 1e-6):
+    def __init__(
+        self,
+        propensity_optimizer: Optional[optax.GradientTransformation] = None,
+        propensity_maxiter: int = 5000,
+        ps_clip_epsilon: float = 1e-6,
+    ):
         """
         Initialize the IPW estimator.
 
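Note: the reformatted signature above keeps the same defaults. For illustration, a custom optimizer can be passed through the constructor; a minimal sketch, assuming only the parameters shown in this hunk (the optax.adam settings are arbitrary, not part of this commit):

import optax
from jaxonometrics.causal import IPW

# Hypothetical configuration: Adam with a larger step size, more iterations,
# and a looser propensity clip than the 1e-6 default. Values illustrative only.
est = IPW(
    propensity_optimizer=optax.adam(1e-2),
    propensity_maxiter=10_000,
    ps_clip_epsilon=1e-4,
)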
@@ -68,12 +73,19 @@ def __init__(self, propensity_optimizer: Optional[optax.GradientTransformation]
         """
         super().__init__()
         from .mle import LogisticRegression
-        self.logit_model = LogisticRegression(optimizer=propensity_optimizer, maxiter=propensity_maxiter)
+
+        self.logit_model = LogisticRegression(
+            optimizer=propensity_optimizer, maxiter=propensity_maxiter
+        )
         self.ps_clip_epsilon = ps_clip_epsilon
         self.params: Dict[str, Any] = {"ate": None, "propensity_scores": None}
 
-
-    def fit(self, X: jnp.ndarray, T: jnp.ndarray, y: jnp.ndarray) -> "IPW":
+    def fit(
+        self,
+        X: jnp.ndarray,
+        W: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> "IPW":
         """
         Estimate the Average Treatment Effect (ATE) using IPW.
 
@@ -87,24 +99,26 @@ def fit(self, X: jnp.ndarray, T: jnp.ndarray, y: jnp.ndarray) -> "IPW":
             The fitted estimator with ATE and propensity scores.
         """
         # Ensure T is jnp.ndarray for Logit model
-        if not isinstance(T, jnp.ndarray):
-            T_jax = jnp.array(T)
+        if not isinstance(W, jnp.ndarray):
+            W_jax = jnp.array(W)
         else:
-            T_jax = T
+            W_jax = W
 
         # 1. Estimate propensity scores P(T=1|X) using Logit
-        self.logit_model.fit(X, T_jax)
+        self.logit_model.fit(X, W_jax)
         propensity_scores = self.logit_model.predict_proba(X)
 
         # Clip propensity scores using the instance attribute
-        propensity_scores = jnp.clip(propensity_scores, self.ps_clip_epsilon, 1 - self.ps_clip_epsilon)
+        propensity_scores = jnp.clip(
+            propensity_scores, self.ps_clip_epsilon, 1 - self.ps_clip_epsilon
+        )
 
         self.params["propensity_scores"] = propensity_scores
 
         # 2. Calculate IPW weights
         # Weight for treated: 1 / p_score
         # Weight for control: 1 / (1 - p_score)
-        weights = T_jax / propensity_scores + (1 - T_jax) / (1 - propensity_scores)
+        weights = W_jax / propensity_scores + (1 - W_jax) / (1 - propensity_scores)
 
         # 3. Estimate ATE: E[Y(1)] - E[Y(0)]
         # E[Y(1)] = sum(T_i * y_i / p_i) / sum(T_i / p_i)
@@ -120,8 +134,8 @@ def fit(self, X: jnp.ndarray, T: jnp.ndarray, y: jnp.ndarray) -> "IPW":
         # This can also be seen as E[ (T - e)Y / (e(1-e)) ]
         # However, the difference of means of weighted outcomes is more standard:
 
-        mean_y1 = jnp.sum(T_jax * y * weights) / jnp.sum(T_jax * weights)
-        mean_y0 = jnp.sum((1 - T_jax) * y * weights) / jnp.sum((1 - T_jax) * weights)
+        mean_y1 = jnp.sum(W_jax * y * weights) / jnp.sum(W_jax * weights)
+        mean_y0 = jnp.sum((1 - W_jax) * y * weights) / jnp.sum((1 - W_jax) * weights)
 
         # The above is equivalent to:
         # mean_y1 = jnp.sum( (T_jax * y) / propensity_scores ) / jnp.sum( T_jax / propensity_scores )
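Note: the two forms referenced in the comments are algebraically identical, because for a binary W the product W_jax * weights collapses to W_jax / propensity_scores. A self-contained sketch verifying this on synthetic arrays (all values hypothetical):

import jax.numpy as jnp

W = jnp.array([1.0, 0.0, 1.0, 0.0])  # treatment indicator
y = jnp.array([2.0, 1.0, 3.0, 0.5])  # outcomes
e = jnp.array([0.6, 0.4, 0.7, 0.3])  # propensity scores

weights = W / e + (1 - W) / (1 - e)

# Self-normalized (Hajek) form used in fit():
mean_y1 = jnp.sum(W * y * weights) / jnp.sum(W * weights)
# Ratio form from the comment above:
mean_y1_alt = jnp.sum(W * y / e) / jnp.sum(W / e)
assert jnp.allclose(mean_y1, mean_y1_alt)  # W * weights == W / e when W is 0/1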
@@ -133,28 +147,36 @@ def fit(self, X: jnp.ndarray, T: jnp.ndarray, y: jnp.ndarray) -> "IPW":
         return self
 
     def summary(self) -> None:
-        super().summary() # Calls BaseEstimator summary
+        super().summary()  # Calls BaseEstimator summary
         if self.params and "ate" in self.params and self.params["ate"] is not None:
             print(f" Estimated ATE: {self.params['ate']:.4f}")
-        if self.params and "propensity_scores" in self.params and self.params["propensity_scores"] is not None:
-            print(f" Propensity scores min: {jnp.min(self.params['propensity_scores']):.4f}, max: {jnp.max(self.params['propensity_scores']):.4f}")
+        if (
+            self.params
+            and "propensity_scores" in self.params
+            and self.params["propensity_scores"] is not None
+        ):
+            print(
+                f" Propensity scores min: {jnp.min(self.params['propensity_scores']):.4f}, max: {jnp.max(self.params['propensity_scores']):.4f}"
+            )
+
 
 # Need to add `Any` to imports for type hinting
 # from typing import Dict, Optional, Any
 # Need to add this at the top of causal.py
 
-from .linear import LinearRegression # For default outcome model in AIPW
 
 class AIPW(BaseEstimator):
     """
     Augmented Inverse Propensity Weighting (AIPW) estimator for ATE.
     Also known as doubly robust estimator.
     """
 
-    def __init__(self,
-                 outcome_model: Optional[BaseEstimator] = None,
-                 propensity_model: Optional[Any] = None, # Should be a Logit instance or similar
-                 ps_clip_epsilon: float = 1e-6):
+    def __init__(
+        self,
+        outcome_model: Optional[BaseEstimator] = None,
+        propensity_model: Optional[Any] = None,  # Should be a Logit instance or similar
+        ps_clip_epsilon: float = 1e-6,
+    ):
         """
         Initialize the AIPW estimator.
 
@@ -170,15 +192,28 @@ def __init__(self,
         super().__init__()
         from .mle import LogisticRegression
 
-        self.outcome_model_template = outcome_model if outcome_model else LinearRegression()
+        self.outcome_model_template = (
+            outcome_model if outcome_model else LinearRegression()
+        )
         # We need two instances of the outcome model, one for T=1 and one for T=0
-        self.propensity_model = propensity_model if propensity_model else LogisticRegression()
+        self.propensity_model = (
+            propensity_model if propensity_model else LogisticRegression()
+        )
 
         self.ps_clip_epsilon = ps_clip_epsilon
-        self.params: Dict[str, Any] = {"ate": None, "propensity_scores": None,
-                                       "mu0_params": None, "mu1_params": None}
+        self.params: Dict[str, Any] = {
+            "ate": None,
+            "propensity_scores": None,
+            "mu0_params": None,
+            "mu1_params": None,
+        }
 
-    def fit(self, X: jnp.ndarray, T: jnp.ndarray, y: jnp.ndarray) -> "AIPW":
+    def fit(
+        self,
+        X: jnp.ndarray,
+        W: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> "AIPW":
         """
         Estimate the Average Treatment Effect (ATE) using AIPW.
 
@@ -191,58 +226,67 @@ def fit(self, X: jnp.ndarray, T: jnp.ndarray, y: jnp.ndarray) -> "AIPW":
         Returns:
             The fitted estimator with ATE.
         """
-        if not isinstance(T, jnp.ndarray): T_jax = jnp.array(T)
-        else: T_jax = T
-        if not isinstance(y, jnp.ndarray): y_jax = jnp.array(y)
-        else: y_jax = y
-        if not isinstance(X, jnp.ndarray): X_jax = jnp.array(X)
-        else: X_jax = X
+        if not isinstance(W, jnp.ndarray):
+            W_jax = jnp.array(W)
+        else:
+            W_jax = W
+        if not isinstance(y, jnp.ndarray):
+            y_jax = jnp.array(y)
+        else:
+            y_jax = y
+        if not isinstance(X, jnp.ndarray):
+            X_jax = jnp.array(X)
+        else:
+            X_jax = X
 
         n_samples = X_jax.shape[0]
 
         # 1. Estimate propensity scores P(T=1|X) = e(X)
-        self.propensity_model.fit(X_jax, T_jax)
+        self.propensity_model.fit(X_jax, W_jax)
         propensity_scores = self.propensity_model.predict_proba(X_jax)
-        propensity_scores = jnp.clip(propensity_scores, self.ps_clip_epsilon, 1 - self.ps_clip_epsilon)
+        propensity_scores = jnp.clip(
+            propensity_scores, self.ps_clip_epsilon, 1 - self.ps_clip_epsilon
+        )
         self.params["propensity_scores"] = propensity_scores
 
         # 2. Estimate outcome models E[Y|X, T=1] = μ_1(X) and E[Y|X, T=0] = μ_0(X)
         # Need to handle potential issues if one group has no samples (though unlikely with real data)
-        X_treated = X_jax[T_jax == 1]
-        y_treated = y_jax[T_jax == 1]
-        X_control = X_jax[T_jax == 0]
-        y_control = y_jax[T_jax == 0]
+        X_treated = X_jax[W_jax == 1]
+        y_treated = y_jax[W_jax == 1]
+        X_control = X_jax[W_jax == 0]
+        y_control = y_jax[W_jax == 0]
 
         # Create fresh instances of the outcome model for fitting
         # This assumes the outcome_model_template can be re-used (e.g. by creating a new instance or being stateless after fit)
         # For sklearn-like models, this means creating new instances.
         # For our JAX models, they are re-fitted.
 
-        model1 = self.outcome_model_template.__class__() # Create a new instance of the same type
+        model1 = (
+            self.outcome_model_template.__class__()
+        )  # Create a new instance of the same type
         if X_treated.shape[0] > 0:
             model1.fit(X_treated, y_treated)
             mu1_X = model1.predict(X_jax)
             self.params["mu1_params"] = model1.params
-        else: # Should not happen in typical scenarios
+        else:  # Should not happen in typical scenarios
             mu1_X = jnp.zeros(n_samples)
             self.params["mu1_params"] = None
 
-        model0 = self.outcome_model_template.__class__() # Create a new instance
+        model0 = self.outcome_model_template.__class__()  # Create a new instance
         if X_control.shape[0] > 0:
             model0.fit(X_control, y_control)
             mu0_X = model0.predict(X_jax)
             self.params["mu0_params"] = model0.params
-        else: # Should not happen
+        else:  # Should not happen
             mu0_X = jnp.zeros(n_samples)
             self.params["mu0_params"] = None
 
-
         # 3. Calculate AIPW estimator components
         # ψ_i = μ_1(X_i) - μ_0(X_i) + T_i/e(X_i) * (Y_i - μ_1(X_i)) - (1-T_i)/(1-e(X_i)) * (Y_i - μ_0(X_i))
 
         term1 = mu1_X - mu0_X
-        term2 = (T_jax / propensity_scores) * (y_jax - mu1_X)
-        term3 = ((1 - T_jax) / (1 - propensity_scores)) * (y_jax - mu0_X)
+        term2 = (W_jax / propensity_scores) * (y_jax - mu1_X)
+        term3 = ((1 - W_jax) / (1 - propensity_scores)) * (y_jax - mu0_X)
 
         psi_i = term1 + term2 - term3
 
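Note: psi_i above is the standard doubly robust score, and its sample mean is the ATE estimate; the estimate stays consistent if either the outcome models or the propensity model is correctly specified. A self-contained sketch on synthetic arrays (all values hypothetical; the standard-error line is a common plug-in, not something this commit computes):

import jax.numpy as jnp

W = jnp.array([1.0, 0.0, 1.0, 0.0])    # treatment indicator
y = jnp.array([2.0, 1.0, 3.0, 0.5])    # outcomes
e = jnp.array([0.6, 0.4, 0.7, 0.3])    # clipped propensity scores
mu1 = jnp.array([2.1, 1.5, 2.8, 1.0])  # outcome predictions under treatment
mu0 = jnp.array([1.0, 0.9, 1.2, 0.6])  # outcome predictions under control

psi = mu1 - mu0 + W / e * (y - mu1) - (1 - W) / (1 - e) * (y - mu0)
ate = jnp.mean(psi)
se = jnp.std(psi, ddof=1) / jnp.sqrt(psi.shape[0])  # plug-in standard error (assumption)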

@@ -255,6 +299,12 @@ def summary(self) -> None:
         super().summary()
         if self.params and "ate" in self.params and self.params["ate"] is not None:
             print(f" Estimated ATE (AIPW): {self.params['ate']:.4f}")
-        if self.params and "propensity_scores" in self.params and self.params["propensity_scores"] is not None:
-            print(f" Propensity scores min: {jnp.min(self.params['propensity_scores']):.4f}, max: {jnp.max(self.params['propensity_scores']):.4f}")
+        if (
+            self.params
+            and "propensity_scores" in self.params
+            and self.params["propensity_scores"] is not None
+        ):
+            print(
+                f" Propensity scores min: {jnp.min(self.params['propensity_scores']):.4f}, max: {jnp.max(self.params['propensity_scores']):.4f}"
+            )
         # Could add info about outcome model parameters if desired
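Net effect of the causal.py changes: the treatment argument is W rather than T throughout (a common econometrics convention that avoids clashing with T for transpose or time), the mid-file LinearRegression import moves to the top of the module, and the file is reformatted in Black style. A hedged usage sketch of the renamed interface, assuming only the classes and signatures shown in this diff:

import jax
import jax.numpy as jnp
from jaxonometrics.causal import AIPW, IPW

# Toy data, illustrative only: confounded binary treatment, linear outcome.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
X = jax.random.normal(k1, (200, 2))
p = jax.nn.sigmoid(X[:, 0])                              # treatment probability
W = (jax.random.uniform(k2, (200,)) < p).astype(jnp.float32)
y = 1.0 + 2.0 * W + X[:, 1] + 0.1 * jax.random.normal(k3, (200,))

ipw = IPW().fit(X, W, y)  # fit(X, W, y) after this commit, not fit(X, T, y)
ipw.summary()

aipw = AIPW().fit(X, W, y)
print(aipw.params["ate"])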

pyproject.toml

Lines changed: 0 additions & 2 deletions
@@ -12,8 +12,6 @@ dependencies = [
     "jaxopt",
     "lineax",
     "optax",
-    "formulaic",
-    "narwhals",
 ]
 requires-python = ">=3.8"
 authors = [
