1- from typing import Dict , Optional
1+ from typing import Dict , Optional , Any
22
3+ import jax # Ensure jax is imported
4+ import optax # Added for type hint
35import jax .numpy as jnp
46from jaxopt import LBFGS
57
68from .base import BaseEstimator
9+ from .linear import LinearRegression # For default outcome model in AIPW
710
811
912class EntropyBalancing (BaseEstimator ):
@@ -19,6 +22,7 @@ def __init__(self):
1922 super ().__init__ ()
2023
2124 @staticmethod
25+ @jax .jit
2226 def _eb_moment (b : jnp .ndarray , X0 : jnp .ndarray , X1 : jnp .ndarray ) -> jnp .ndarray :
2327 """The moment condition for entropy balancing."""
2428 return jnp .log (jnp .exp (- 1 * X0 @ b ).sum ()) + X1 @ b
@@ -44,3 +48,263 @@ def fit(
4448 wt /= wt .sum ()
4549 self .params = {"weights" : wt }
4650 return self
51+
52+
53+ class IPW (BaseEstimator ):
54+ """
55+ Inverse Propensity Weighting estimator for Average Treatment Effect (ATE).
56+ This implementation uses Logistic Regression for propensity score estimation.
57+ """
58+
59+ def __init__ (
60+ self ,
61+ propensity_optimizer : Optional [optax .GradientTransformation ] = None ,
62+ propensity_maxiter : int = 5000 ,
63+ ps_clip_epsilon : float = 1e-6 ,
64+ ):
65+ """
66+ Initialize the IPW estimator.
67+
68+ Args:
69+ propensity_optimizer: Optional Optax optimizer for the Logit propensity score model.
70+ Defaults to optax.adam(1e-3) if None.
71+ propensity_maxiter: Maximum iterations for the Logit model optimization.
72+ ps_clip_epsilon: Small constant to clip propensity scores.
73+ """
74+ super ().__init__ ()
75+ from .mle import LogisticRegression
76+
77+ self .logit_model = LogisticRegression (
78+ optimizer = propensity_optimizer , maxiter = propensity_maxiter
79+ )
80+ self .ps_clip_epsilon = ps_clip_epsilon
81+ self .params : Dict [str , Any ] = {"ate" : None , "propensity_scores" : None }
82+
83+ def fit (
84+ self ,
85+ X : jnp .ndarray ,
86+ W : jnp .ndarray ,
87+ y : jnp .ndarray ,
88+ ) -> "IPW" :
89+ """
90+ Estimate the Average Treatment Effect (ATE) using IPW.
91+
92+ Args:
93+ X: Covariate matrix of shape (n_samples, n_features).
94+ It's assumed that X includes an intercept column if one is desired for the propensity score model.
95+ T: Treatment assignment vector (binary, 0 or 1) of shape (n_samples,).
96+ y: Outcome vector of shape (n_samples,).
97+
98+ Returns:
99+ The fitted estimator with ATE and propensity scores.
100+ """
101+ # Ensure T is jnp.ndarray for Logit model
102+ if not isinstance (W , jnp .ndarray ):
103+ W_jax = jnp .array (W )
104+ else :
105+ W_jax = W
106+
107+ # 1. Estimate propensity scores P(T=1|X) using Logit
108+ self .logit_model .fit (X , W_jax )
109+ propensity_scores = self .logit_model .predict_proba (X )
110+
111+ # Clip propensity scores using the instance attribute
112+ propensity_scores = jnp .clip (
113+ propensity_scores , self .ps_clip_epsilon , 1 - self .ps_clip_epsilon
114+ )
115+
116+ self .params ["propensity_scores" ] = propensity_scores
117+
118+ # 2. Calculate IPW weights
119+ # Weight for treated: 1 / p_score
120+ # Weight for control: 1 / (1 - p_score)
121+ weights = W_jax / propensity_scores + (1 - W_jax ) / (1 - propensity_scores )
122+
123+ # 3. Estimate ATE: E[Y(1)] - E[Y(0)]
124+ # E[Y(1)] = sum(T_i * y_i / p_i) / sum(T_i / p_i)
125+ # E[Y(0)] = sum((1-T_i) * y_i / (1-p_i)) / sum((1-T_i) / (1-p_i))
126+ # ATE = sum( (T_i/p_i - (1-T_i)/(1-p_i)) * y_i ) / N (Hahn, 1998 type estimator)
127+ # Or, more commonly for Horvitz-Thompson type:
128+ # E[Y_1] = sum(T*y / ps) / sum(T/ps)
129+ # E[Y_0] = sum((1-T)*y / (1-ps)) / sum((1-T)/(1-ps))
130+ # ATE = E[Y_1] - E[Y_0]
131+
132+ # Using the simpler weighted average formulation for ATE:
133+ # ATE = (1/N) * Σ [ (T_i * Y_i / e(X_i)) - ((1-T_i) * Y_i / (1-e(X_i))) ]
134+ # This can also be seen as E[ (T - e)Y / (e(1-e)) ]
135+ # However, the difference of means of weighted outcomes is more standard:
136+
137+ mean_y1 = jnp .sum (W_jax * y * weights ) / jnp .sum (W_jax * weights )
138+ mean_y0 = jnp .sum ((1 - W_jax ) * y * weights ) / jnp .sum ((1 - W_jax ) * weights )
139+
140+ # The above is equivalent to:
141+ # mean_y1 = jnp.sum( (T_jax * y) / propensity_scores ) / jnp.sum( T_jax / propensity_scores )
142+ # mean_y0 = jnp.sum( ((1-T_jax) * y) / (1-propensity_scores) ) / jnp.sum( (1-T_jax) / (1-propensity_scores) )
143+
144+ ate = mean_y1 - mean_y0
145+ self .params ["ate" ] = ate
146+
147+ return self
148+
149+ def summary (self ) -> None :
150+ super ().summary () # Calls BaseEstimator summary
151+ if self .params and "ate" in self .params and self .params ["ate" ] is not None :
152+ print (f" Estimated ATE: { self .params ['ate' ]:.4f} " )
153+ if (
154+ self .params
155+ and "propensity_scores" in self .params
156+ and self .params ["propensity_scores" ] is not None
157+ ):
158+ print (
159+ f" Propensity scores min: { jnp .min (self .params ['propensity_scores' ]):.4f} , max: { jnp .max (self .params ['propensity_scores' ]):.4f} "
160+ )
161+
162+
163+ # Need to add `Any` to imports for type hinting
164+ # from typing import Dict, Optional, Any
165+ # Need to add this at the top of causal.py
166+
167+
168+ class AIPW (BaseEstimator ):
169+ """
170+ Augmented Inverse Propensity Weighting (AIPW) estimator for ATE.
171+ Also known as doubly robust estimator.
172+ """
173+
174+ def __init__ (
175+ self ,
176+ outcome_model : Optional [BaseEstimator ] = None ,
177+ propensity_model : Optional [Any ] = None , # Should be a Logit instance or similar
178+ ps_clip_epsilon : float = 1e-6 ,
179+ ):
180+ """
181+ Initialize the AIPW estimator.
182+
183+ Args:
184+ outcome_model: A regression model (like LinearRegression or a custom one)
185+ to estimate E[Y|X, T=t]. If None, LinearRegression is used.
186+ The model should have a `fit(X,y)` and `predict(X)` method.
187+ propensity_model: A binary classifier (like Logit) to estimate P(T=1|X).
188+ If None, Logit() is used. Model should have `fit(X,T)`
189+ and `predict_proba(X)` methods.
190+ ps_clip_epsilon: Small constant to clip propensity scores to avoid extreme values.
191+ """
192+ super ().__init__ ()
193+ from .mle import LogisticRegression
194+
195+ self .outcome_model_template = (
196+ outcome_model if outcome_model else LinearRegression ()
197+ )
198+ # We need two instances of the outcome model, one for T=1 and one for T=0
199+ self .propensity_model = (
200+ propensity_model if propensity_model else LogisticRegression ()
201+ )
202+
203+ self .ps_clip_epsilon = ps_clip_epsilon
204+ self .params : Dict [str , Any ] = {
205+ "ate" : None ,
206+ "propensity_scores" : None ,
207+ "mu0_params" : None ,
208+ "mu1_params" : None ,
209+ }
210+
211+ def fit (
212+ self ,
213+ X : jnp .ndarray ,
214+ W : jnp .ndarray ,
215+ y : jnp .ndarray ,
216+ ) -> "AIPW" :
217+ """
218+ Estimate the Average Treatment Effect (ATE) using AIPW.
219+
220+ Args:
221+ X: Covariate matrix of shape (n_samples, n_features).
222+ It's assumed that X includes an intercept column if one is desired for the outcome and propensity score models.
223+ T: Treatment assignment vector (binary, 0 or 1) of shape (n_samples,).
224+ y: Outcome vector of shape (n_samples,).
225+
226+ Returns:
227+ The fitted estimator with ATE.
228+ """
229+ if not isinstance (W , jnp .ndarray ):
230+ W_jax = jnp .array (W )
231+ else :
232+ W_jax = W
233+ if not isinstance (y , jnp .ndarray ):
234+ y_jax = jnp .array (y )
235+ else :
236+ y_jax = y
237+ if not isinstance (X , jnp .ndarray ):
238+ X_jax = jnp .array (X )
239+ else :
240+ X_jax = X
241+
242+ n_samples = X_jax .shape [0 ]
243+
244+ # 1. Estimate propensity scores P(T=1|X) = e(X)
245+ self .propensity_model .fit (X_jax , W_jax )
246+ propensity_scores = self .propensity_model .predict_proba (X_jax )
247+ propensity_scores = jnp .clip (
248+ propensity_scores , self .ps_clip_epsilon , 1 - self .ps_clip_epsilon
249+ )
250+ self .params ["propensity_scores" ] = propensity_scores
251+
252+ # 2. Estimate outcome models E[Y|X, T=1] = μ_1(X) and E[Y|X, T=0] = μ_0(X)
253+ # Need to handle potential issues if one group has no samples (though unlikely with real data)
254+ X_treated = X_jax [W_jax == 1 ]
255+ y_treated = y_jax [W_jax == 1 ]
256+ X_control = X_jax [W_jax == 0 ]
257+ y_control = y_jax [W_jax == 0 ]
258+
259+ # Create fresh instances of the outcome model for fitting
260+ # This assumes the outcome_model_template can be re-used (e.g. by creating a new instance or being stateless after fit)
261+ # For sklearn-like models, this means creating new instances.
262+ # For our JAX models, they are re-fitted.
263+
264+ model1 = (
265+ self .outcome_model_template .__class__ ()
266+ ) # Create a new instance of the same type
267+ if X_treated .shape [0 ] > 0 :
268+ model1 .fit (X_treated , y_treated )
269+ mu1_X = model1 .predict (X_jax )
270+ self .params ["mu1_params" ] = model1 .params
271+ else : # Should not happen in typical scenarios
272+ mu1_X = jnp .zeros (n_samples )
273+ self .params ["mu1_params" ] = None
274+
275+ model0 = self .outcome_model_template .__class__ () # Create a new instance
276+ if X_control .shape [0 ] > 0 :
277+ model0 .fit (X_control , y_control )
278+ mu0_X = model0 .predict (X_jax )
279+ self .params ["mu0_params" ] = model0 .params
280+ else : # Should not happen
281+ mu0_X = jnp .zeros (n_samples )
282+ self .params ["mu0_params" ] = None
283+
284+ # 3. Calculate AIPW estimator components
285+ # ψ_i = μ_1(X_i) - μ_0(X_i) + T_i/e(X_i) * (Y_i - μ_1(X_i)) - (1-T_i)/(1-e(X_i)) * (Y_i - μ_0(X_i))
286+
287+ term1 = mu1_X - mu0_X
288+ term2 = (W_jax / propensity_scores ) * (y_jax - mu1_X )
289+ term3 = ((1 - W_jax ) / (1 - propensity_scores )) * (y_jax - mu0_X )
290+
291+ psi_i = term1 + term2 - term3
292+
293+ ate = jnp .mean (psi_i )
294+ self .params ["ate" ] = ate
295+
296+ return self
297+
298+ def summary (self ) -> None :
299+ super ().summary ()
300+ if self .params and "ate" in self .params and self .params ["ate" ] is not None :
301+ print (f" Estimated ATE (AIPW): { self .params ['ate' ]:.4f} " )
302+ if (
303+ self .params
304+ and "propensity_scores" in self .params
305+ and self .params ["propensity_scores" ] is not None
306+ ):
307+ print (
308+ f" Propensity scores min: { jnp .min (self .params ['propensity_scores' ]):.4f} , max: { jnp .max (self .params ['propensity_scores' ]):.4f} "
309+ )
310+ # Could add info about outcome model parameters if desired
0 commit comments