
Commit 8e04850

Merge pull request #2 from apoorvalal/feat/mle-causal-methods
feat: Implement MLE and IPW/AIPW
2 parents 0ff479b + 68022a0 commit 8e04850

File tree: 10 files changed, +857 -32 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -27,3 +27,4 @@
 *.blg
 *.out
 *.synctex.gz
+nb/scratch.ipynb

jaxonometrics/__init__.py

Lines changed: 7 additions & 1 deletion

@@ -5,15 +5,21 @@
 __version__ = "0.0.1"

 from .base import BaseEstimator
-from .causal import EntropyBalancing
+from .causal import EntropyBalancing, IPW, AIPW
 from .gmm import GMM, LinearIVGMM, TwoStepGMM
 from .linear import LinearRegression
+from .mle import LogisticRegression, PoissonRegression, MaximumLikelihoodEstimator

 __all__ = [
     "BaseEstimator",
     "EntropyBalancing",
+    "IPW",
+    "AIPW",
     "GMM",
     "LinearIVGMM",
     "TwoStepGMM",
     "LinearRegression",
+    "MaximumLikelihoodEstimator",
+    "LogisticRegression",
+    "PoissonRegression",
 ]
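A minimal usage sketch of the newly exported estimators, assuming the fit(X, W, y) signature defined in causal.py below; the synthetic data-generating code here is illustrative and not part of the commit:

import numpy as np
import jax.numpy as jnp
from jaxonometrics import IPW, AIPW

rng = np.random.default_rng(0)
n = 2_000
X = jnp.column_stack([jnp.ones(n), jnp.array(rng.normal(size=n))])  # intercept + one covariate
e = 1.0 / (1.0 + jnp.exp(-X[:, 1]))                                 # true propensity score
W = jnp.array(rng.binomial(1, np.asarray(e)))                       # binary treatment
y = 2.0 * W + X[:, 1] + jnp.array(rng.normal(size=n))               # true ATE = 2

ipw = IPW().fit(X, W, y)    # propensity scores via LogisticRegression
aipw = AIPW().fit(X, W, y)  # adds a LinearRegression outcome model
print(ipw.params["ate"], aipw.params["ate"])  # both should land near 2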

jaxonometrics/causal.py

Lines changed: 265 additions & 1 deletion
@@ -1,9 +1,12 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Any

+import jax
+import optax  # used only for the optax.GradientTransformation type hint
 import jax.numpy as jnp
 from jaxopt import LBFGS

 from .base import BaseEstimator
+from .linear import LinearRegression  # default outcome model for AIPW


 class EntropyBalancing(BaseEstimator):
@@ -19,6 +22,7 @@ def __init__(self):
         super().__init__()

     @staticmethod
+    @jax.jit
     def _eb_moment(b: jnp.ndarray, X0: jnp.ndarray, X1: jnp.ndarray) -> jnp.ndarray:
         """The moment condition for entropy balancing."""
         return jnp.log(jnp.exp(-1 * X0 @ b).sum()) + X1 @ b
@@ -44,3 +48,263 @@ def fit(
         wt /= wt.sum()
         self.params = {"weights": wt}
         return self
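For reference, the now-jitted _eb_moment is the dual objective of entropy balancing (Hainmueller, 2012). Assuming, as its usage in fit suggests, that X0 holds the control units' covariates and X1 the treated-group covariate means, LBFGS solves

    \min_b \; \log\Big(\sum_{i:\,W_i=0} \exp(-x_i^\top b)\Big) + \bar{x}_1^\top b,

and the implied control weights w_i \propto \exp(-x_i^\top \hat{b}), normalized to sum to one, match the weighted control covariate means to the treated means.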
+
+
+class IPW(BaseEstimator):
+    """
+    Inverse Propensity Weighting estimator for the Average Treatment Effect (ATE).
+    This implementation uses logistic regression for propensity score estimation.
+    """
+
+    def __init__(
+        self,
+        propensity_optimizer: Optional[optax.GradientTransformation] = None,
+        propensity_maxiter: int = 5000,
+        ps_clip_epsilon: float = 1e-6,
+    ):
+        """
+        Initialize the IPW estimator.
+
+        Args:
+            propensity_optimizer: Optional Optax optimizer for the logistic propensity
+                score model. Defaults to optax.adam(1e-3) if None.
+            propensity_maxiter: Maximum iterations for the propensity model optimization.
+            ps_clip_epsilon: Small constant used to clip propensity scores away from 0 and 1.
+        """
+        super().__init__()
+        from .mle import LogisticRegression
+
+        self.logit_model = LogisticRegression(
+            optimizer=propensity_optimizer, maxiter=propensity_maxiter
+        )
+        self.ps_clip_epsilon = ps_clip_epsilon
+        self.params: Dict[str, Any] = {"ate": None, "propensity_scores": None}
+
+    def fit(
+        self,
+        X: jnp.ndarray,
+        W: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> "IPW":
+        """
+        Estimate the Average Treatment Effect (ATE) using IPW.
+
+        Args:
+            X: Covariate matrix of shape (n_samples, n_features). X is assumed to
+                include an intercept column if one is desired for the propensity model.
+            W: Treatment assignment vector (binary, 0 or 1) of shape (n_samples,).
+            y: Outcome vector of shape (n_samples,).
+
+        Returns:
+            The fitted estimator with ATE and propensity scores.
+        """
+        # Ensure W is a jnp.ndarray for the propensity model
+        if not isinstance(W, jnp.ndarray):
+            W_jax = jnp.array(W)
+        else:
+            W_jax = W
+
+        # 1. Estimate propensity scores e(X) = P(W=1|X) via logistic regression
+        self.logit_model.fit(X, W_jax)
+        propensity_scores = self.logit_model.predict_proba(X)
+
+        # Clip propensity scores away from 0 and 1 for numerical stability
+        propensity_scores = jnp.clip(
+            propensity_scores, self.ps_clip_epsilon, 1 - self.ps_clip_epsilon
+        )
+        self.params["propensity_scores"] = propensity_scores
+
+        # 2. IPW weights: 1/e(X) for treated units, 1/(1 - e(X)) for controls
+        weights = W_jax / propensity_scores + (1 - W_jax) / (1 - propensity_scores)
+
+        # 3. Estimate ATE = E[Y(1)] - E[Y(0)] with the Hajek (self-normalized) estimator:
+        #    E[Y(1)] = sum(W_i * y_i / e_i) / sum(W_i / e_i)
+        #    E[Y(0)] = sum((1-W_i) * y_i / (1-e_i)) / sum((1-W_i) / (1-e_i))
+        # Normalizing by the realized weight sums, rather than by n as in the
+        # Horvitz-Thompson estimator, is more stable when weights are extreme.
+        mean_y1 = jnp.sum(W_jax * y * weights) / jnp.sum(W_jax * weights)
+        mean_y0 = jnp.sum((1 - W_jax) * y * weights) / jnp.sum((1 - W_jax) * weights)
+
+        ate = mean_y1 - mean_y0
+        self.params["ate"] = ate
+
+        return self
+
+    def summary(self) -> None:
+        super().summary()  # prints the BaseEstimator summary
+        if self.params and "ate" in self.params and self.params["ate"] is not None:
+            print(f"  Estimated ATE: {self.params['ate']:.4f}")
+        if (
+            self.params
+            and "propensity_scores" in self.params
+            and self.params["propensity_scores"] is not None
+        ):
+            print(
+                f"  Propensity scores min: {jnp.min(self.params['propensity_scores']):.4f}, max: {jnp.max(self.params['propensity_scores']):.4f}"
+            )
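As a cross-check on the class above, the same self-normalized (Hajek) computation can be written directly in jax.numpy against propensity scores from any model; a minimal sketch with illustrative toy data:

import jax.numpy as jnp

def hajek_ate(y, W, ps, eps=1e-6):
    """Self-normalized IPW estimate of the ATE, given propensity scores ps."""
    ps = jnp.clip(ps, eps, 1 - eps)
    w = W / ps + (1 - W) / (1 - ps)                        # same weights as IPW.fit
    y1 = jnp.sum(W * y * w) / jnp.sum(W * w)               # weighted treated mean
    y0 = jnp.sum((1 - W) * y * w) / jnp.sum((1 - W) * w)   # weighted control mean
    return y1 - y0

# With constant ps = 0.5 this reduces to a simple difference in means:
y = jnp.array([1.0, 2.0, 3.0, 4.0])
W = jnp.array([1.0, 1.0, 0.0, 0.0])
print(hajek_ate(y, W, jnp.full(4, 0.5)))  # (1+2)/2 - (3+4)/2 = -2.0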
+
+
+class AIPW(BaseEstimator):
+    """
+    Augmented Inverse Propensity Weighting (AIPW) estimator for the ATE,
+    also known as the doubly robust estimator.
+    """
+
+    def __init__(
+        self,
+        outcome_model: Optional[BaseEstimator] = None,
+        propensity_model: Optional[Any] = None,  # a LogisticRegression instance or similar
+        ps_clip_epsilon: float = 1e-6,
+    ):
+        """
+        Initialize the AIPW estimator.
+
+        Args:
+            outcome_model: A regression model (such as LinearRegression) used to
+                estimate E[Y|X, W=w]. If None, LinearRegression is used. The model
+                must expose `fit(X, y)` and `predict(X)` methods.
+            propensity_model: A binary classifier (such as LogisticRegression) used
+                to estimate P(W=1|X). If None, LogisticRegression() is used. The
+                model must expose `fit(X, W)` and `predict_proba(X)` methods.
+            ps_clip_epsilon: Small constant used to clip propensity scores away from 0 and 1.
+        """
+        super().__init__()
+        from .mle import LogisticRegression
+
+        self.outcome_model_template = (
+            outcome_model if outcome_model else LinearRegression()
+        )
+        # Two fresh instances of the outcome model are fitted later, one per treatment arm
+        self.propensity_model = (
+            propensity_model if propensity_model else LogisticRegression()
+        )
+
+        self.ps_clip_epsilon = ps_clip_epsilon
+        self.params: Dict[str, Any] = {
+            "ate": None,
+            "propensity_scores": None,
+            "mu0_params": None,
+            "mu1_params": None,
+        }
+
+    def fit(
+        self,
+        X: jnp.ndarray,
+        W: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> "AIPW":
+        """
+        Estimate the Average Treatment Effect (ATE) using AIPW.
+
+        Args:
+            X: Covariate matrix of shape (n_samples, n_features). X is assumed to
+                include an intercept column if one is desired for the outcome and
+                propensity score models.
+            W: Treatment assignment vector (binary, 0 or 1) of shape (n_samples,).
+            y: Outcome vector of shape (n_samples,).
+
+        Returns:
+            The fitted estimator with ATE.
+        """
+        W_jax = W if isinstance(W, jnp.ndarray) else jnp.array(W)
+        y_jax = y if isinstance(y, jnp.ndarray) else jnp.array(y)
+        X_jax = X if isinstance(X, jnp.ndarray) else jnp.array(X)
+
+        n_samples = X_jax.shape[0]
+
+        # 1. Estimate propensity scores e(X) = P(W=1|X), clipped for stability
+        self.propensity_model.fit(X_jax, W_jax)
+        propensity_scores = self.propensity_model.predict_proba(X_jax)
+        propensity_scores = jnp.clip(
+            propensity_scores, self.ps_clip_epsilon, 1 - self.ps_clip_epsilon
+        )
+        self.params["propensity_scores"] = propensity_scores
+
+        # 2. Estimate the outcome models mu_1(X) = E[Y|X, W=1] and mu_0(X) = E[Y|X, W=0]
+        X_treated = X_jax[W_jax == 1]
+        y_treated = y_jax[W_jax == 1]
+        X_control = X_jax[W_jax == 0]
+        y_control = y_jax[W_jax == 0]
+
+        # Fit a fresh instance of the outcome model per treatment arm; for
+        # sklearn-like templates this means constructing a new object of the same type.
+        model1 = self.outcome_model_template.__class__()
+        if X_treated.shape[0] > 0:
+            model1.fit(X_treated, y_treated)
+            mu1_X = model1.predict(X_jax)
+            self.params["mu1_params"] = model1.params
+        else:  # degenerate case: no treated units
+            mu1_X = jnp.zeros(n_samples)
+            self.params["mu1_params"] = None
+
+        model0 = self.outcome_model_template.__class__()
+        if X_control.shape[0] > 0:
+            model0.fit(X_control, y_control)
+            mu0_X = model0.predict(X_jax)
+            self.params["mu0_params"] = model0.params
+        else:  # degenerate case: no control units
+            mu0_X = jnp.zeros(n_samples)
+            self.params["mu0_params"] = None
+
+        # 3. AIPW (doubly robust) score for each unit:
+        #    psi_i = mu_1(X_i) - mu_0(X_i)
+        #            + W_i / e(X_i) * (Y_i - mu_1(X_i))
+        #            - (1 - W_i) / (1 - e(X_i)) * (Y_i - mu_0(X_i))
+        term1 = mu1_X - mu0_X
+        term2 = (W_jax / propensity_scores) * (y_jax - mu1_X)
+        term3 = ((1 - W_jax) / (1 - propensity_scores)) * (y_jax - mu0_X)
+
+        psi_i = term1 + term2 - term3
+
+        ate = jnp.mean(psi_i)
+        self.params["ate"] = ate
+
+        return self
+
+    def summary(self) -> None:
+        super().summary()
+        if self.params and "ate" in self.params and self.params["ate"] is not None:
+            print(f"  Estimated ATE (AIPW): {self.params['ate']:.4f}")
+        if (
+            self.params
+            and "propensity_scores" in self.params
+            and self.params["propensity_scores"] is not None
+        ):
+            print(
+                f"  Propensity scores min: {jnp.min(self.params['propensity_scores']):.4f}, max: {jnp.max(self.params['propensity_scores']):.4f}"
+            )
+        # Could add info about outcome model parameters if desired
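For reference, the per-unit score that fit averages is the standard AIPW influence function

    \psi_i = \hat{\mu}_1(X_i) - \hat{\mu}_0(X_i) + \frac{W_i\,\big(Y_i - \hat{\mu}_1(X_i)\big)}{\hat{e}(X_i)} - \frac{(1 - W_i)\,\big(Y_i - \hat{\mu}_0(X_i)\big)}{1 - \hat{e}(X_i)},

whose sample mean is consistent for the ATE if either the outcome models or the propensity model is correctly specified (the doubly robust property).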
