11from typing import Dict , Optional , Any
22
3- import jax # Ensure jax is imported
4- import optax # Added for type hint
3+ import jax # Ensure jax is imported
4+ import optax # Added for type hint
55import jax .numpy as jnp
66from jaxopt import LBFGS
77
88from .base import BaseEstimator
9+ from .linear import LinearRegression # For default outcome model in AIPW
910
1011
1112class EntropyBalancing (BaseEstimator ):
@@ -55,8 +56,12 @@ class IPW(BaseEstimator):
5556 This implementation uses Logistic Regression for propensity score estimation.
5657 """
5758
58- def __init__ (self , propensity_optimizer : Optional [optax .GradientTransformation ] = None ,
59- propensity_maxiter : int = 5000 , ps_clip_epsilon : float = 1e-6 ):
59+ def __init__ (
60+ self ,
61+ propensity_optimizer : Optional [optax .GradientTransformation ] = None ,
62+ propensity_maxiter : int = 5000 ,
63+ ps_clip_epsilon : float = 1e-6 ,
64+ ):
6065 """
6166 Initialize the IPW estimator.
6267
@@ -68,12 +73,19 @@ def __init__(self, propensity_optimizer: Optional[optax.GradientTransformation]
6873 """
6974 super ().__init__ ()
7075 from .mle import LogisticRegression
71- self .logit_model = LogisticRegression (optimizer = propensity_optimizer , maxiter = propensity_maxiter )
76+
77+ self .logit_model = LogisticRegression (
78+ optimizer = propensity_optimizer , maxiter = propensity_maxiter
79+ )
7280 self .ps_clip_epsilon = ps_clip_epsilon
7381 self .params : Dict [str , Any ] = {"ate" : None , "propensity_scores" : None }
7482
75-
76- def fit (self , X : jnp .ndarray , T : jnp .ndarray , y : jnp .ndarray ) -> "IPW" :
83+ def fit (
84+ self ,
85+ X : jnp .ndarray ,
86+ W : jnp .ndarray ,
87+ y : jnp .ndarray ,
88+ ) -> "IPW" :
7789 """
7890 Estimate the Average Treatment Effect (ATE) using IPW.
7991
@@ -87,24 +99,26 @@ def fit(self, X: jnp.ndarray, T: jnp.ndarray, y: jnp.ndarray) -> "IPW":
8799 The fitted estimator with ATE and propensity scores.
88100 """
89101 # Ensure W (the treatment indicator) is a jnp.ndarray for the Logit model
90- if not isinstance (T , jnp .ndarray ):
91- T_jax = jnp .array (T )
102+ if not isinstance (W , jnp .ndarray ):
103+ W_jax = jnp .array (W )
92104 else :
93- T_jax = T
105+ W_jax = W
94106
95107 # 1. Estimate propensity scores P(T=1|X) using Logit
96- self .logit_model .fit (X , T_jax )
108+ self .logit_model .fit (X , W_jax )
97109 propensity_scores = self .logit_model .predict_proba (X )
98110
99111 # Clip propensity scores using the instance attribute
100- propensity_scores = jnp .clip (propensity_scores , self .ps_clip_epsilon , 1 - self .ps_clip_epsilon )
112+ propensity_scores = jnp .clip (
113+ propensity_scores , self .ps_clip_epsilon , 1 - self .ps_clip_epsilon
114+ )
101115
102116 self .params ["propensity_scores" ] = propensity_scores
103117
104118 # 2. Calculate IPW weights
105119 # Weight for treated: 1 / p_score
106120 # Weight for control: 1 / (1 - p_score)
107- weights = T_jax / propensity_scores + (1 - T_jax ) / (1 - propensity_scores )
121+ weights = W_jax / propensity_scores + (1 - W_jax ) / (1 - propensity_scores )
108122
109123 # 3. Estimate ATE: E[Y(1)] - E[Y(0)]
110124 # E[Y(1)] = sum(T_i * y_i / p_i) / sum(T_i / p_i)
@@ -120,8 +134,8 @@ def fit(self, X: jnp.ndarray, T: jnp.ndarray, y: jnp.ndarray) -> "IPW":
120134 # This can also be seen as E[ (T - e)Y / (e(1-e)) ]
121135 # However, the difference of means of weighted outcomes is more standard:
122136
123- mean_y1 = jnp .sum (T_jax * y * weights ) / jnp .sum (T_jax * weights )
124- mean_y0 = jnp .sum ((1 - T_jax ) * y * weights ) / jnp .sum ((1 - T_jax ) * weights )
137+ mean_y1 = jnp .sum (W_jax * y * weights ) / jnp .sum (W_jax * weights )
138+ mean_y0 = jnp .sum ((1 - W_jax ) * y * weights ) / jnp .sum ((1 - W_jax ) * weights )
125139
126140 # The above is equivalent to:
127141 # mean_y1 = jnp.sum( (W_jax * y) / propensity_scores ) / jnp.sum( W_jax / propensity_scores )
@@ -133,28 +147,36 @@ def fit(self, X: jnp.ndarray, T: jnp.ndarray, y: jnp.ndarray) -> "IPW":
133147 return self
134148
135149 def summary (self ) -> None :
136- super ().summary () # Calls BaseEstimator summary
150+ super ().summary () # Calls BaseEstimator summary
137151 if self .params and "ate" in self .params and self .params ["ate" ] is not None :
138152 print (f" Estimated ATE: { self .params ['ate' ]:.4f} " )
139- if self .params and "propensity_scores" in self .params and self .params ["propensity_scores" ] is not None :
140- print (f" Propensity scores min: { jnp .min (self .params ['propensity_scores' ]):.4f} , max: { jnp .max (self .params ['propensity_scores' ]):.4f} " )
153+ if (
154+ self .params
155+ and "propensity_scores" in self .params
156+ and self .params ["propensity_scores" ] is not None
157+ ):
158+ print (
159+ f" Propensity scores min: { jnp .min (self .params ['propensity_scores' ]):.4f} , max: { jnp .max (self .params ['propensity_scores' ]):.4f} "
160+ )
161+
141162
142163# NOTE: `Any` is needed for the type hints in this module. The import
143164# from typing import Dict, Optional, Any
144165# belongs at the top of causal.py, where it is already present.
145166
146- from .linear import LinearRegression # For default outcome model in AIPW
147167
148168class AIPW (BaseEstimator ):
149169 """
150170 Augmented Inverse Propensity Weighting (AIPW) estimator for ATE.
151171 Also known as doubly robust estimator.
152172 """
153173
154- def __init__ (self ,
155- outcome_model : Optional [BaseEstimator ] = None ,
156- propensity_model : Optional [Any ] = None , # Should be a Logit instance or similar
157- ps_clip_epsilon : float = 1e-6 ):
174+ def __init__ (
175+ self ,
176+ outcome_model : Optional [BaseEstimator ] = None ,
177+ propensity_model : Optional [Any ] = None , # Should be a Logit instance or similar
178+ ps_clip_epsilon : float = 1e-6 ,
179+ ):
158180 """
159181 Initialize the AIPW estimator.
160182
@@ -170,15 +192,28 @@ def __init__(self,
170192 super ().__init__ ()
171193 from .mle import LogisticRegression
172194
173- self .outcome_model_template = outcome_model if outcome_model else LinearRegression ()
195+ self .outcome_model_template = (
196+ outcome_model if outcome_model else LinearRegression ()
197+ )
174198 # We need two instances of the outcome model, one for T=1 and one for T=0
175- self .propensity_model = propensity_model if propensity_model else LogisticRegression ()
199+ self .propensity_model = (
200+ propensity_model if propensity_model else LogisticRegression ()
201+ )
176202
177203 self .ps_clip_epsilon = ps_clip_epsilon
178- self .params : Dict [str , Any ] = {"ate" : None , "propensity_scores" : None ,
179- "mu0_params" : None , "mu1_params" : None }
204+ self .params : Dict [str , Any ] = {
205+ "ate" : None ,
206+ "propensity_scores" : None ,
207+ "mu0_params" : None ,
208+ "mu1_params" : None ,
209+ }
180210
181- def fit (self , X : jnp .ndarray , T : jnp .ndarray , y : jnp .ndarray ) -> "AIPW" :
211+ def fit (
212+ self ,
213+ X : jnp .ndarray ,
214+ W : jnp .ndarray ,
215+ y : jnp .ndarray ,
216+ ) -> "AIPW" :
182217 """
183218 Estimate the Average Treatment Effect (ATE) using AIPW.
184219
@@ -191,58 +226,67 @@ def fit(self, X: jnp.ndarray, T: jnp.ndarray, y: jnp.ndarray) -> "AIPW":
191226 Returns:
192227 The fitted estimator with ATE.
193228 """
194- if not isinstance (T , jnp .ndarray ): T_jax = jnp .array (T )
195- else : T_jax = T
196- if not isinstance (y , jnp .ndarray ): y_jax = jnp .array (y )
197- else : y_jax = y
198- if not isinstance (X , jnp .ndarray ): X_jax = jnp .array (X )
199- else : X_jax = X
229+ if not isinstance (W , jnp .ndarray ):
230+ W_jax = jnp .array (W )
231+ else :
232+ W_jax = W
233+ if not isinstance (y , jnp .ndarray ):
234+ y_jax = jnp .array (y )
235+ else :
236+ y_jax = y
237+ if not isinstance (X , jnp .ndarray ):
238+ X_jax = jnp .array (X )
239+ else :
240+ X_jax = X
200241
201242 n_samples = X_jax .shape [0 ]
202243
203244 # 1. Estimate propensity scores P(T=1|X) = e(X)
204- self .propensity_model .fit (X_jax , T_jax )
245+ self .propensity_model .fit (X_jax , W_jax )
205246 propensity_scores = self .propensity_model .predict_proba (X_jax )
206- propensity_scores = jnp .clip (propensity_scores , self .ps_clip_epsilon , 1 - self .ps_clip_epsilon )
247+ propensity_scores = jnp .clip (
248+ propensity_scores , self .ps_clip_epsilon , 1 - self .ps_clip_epsilon
249+ )
207250 self .params ["propensity_scores" ] = propensity_scores
208251
209252 # 2. Estimate outcome models E[Y|X, T=1] = μ_1(X) and E[Y|X, T=0] = μ_0(X)
210253 # Need to handle potential issues if one group has no samples (though unlikely with real data)
211- X_treated = X_jax [T_jax == 1 ]
212- y_treated = y_jax [T_jax == 1 ]
213- X_control = X_jax [T_jax == 0 ]
214- y_control = y_jax [T_jax == 0 ]
254+ X_treated = X_jax [W_jax == 1 ]
255+ y_treated = y_jax [W_jax == 1 ]
256+ X_control = X_jax [W_jax == 0 ]
257+ y_control = y_jax [W_jax == 0 ]
215258
216259 # Create fresh instances of the outcome model for fitting
217260 # The template itself is not fitted: each arm gets a new instance built with a no-argument `__class__()` call, so any constructor configuration on the template is not copied (NOTE(review): confirm this is intended)
218261 # For sklearn-like models, this means creating new instances.
219262 # For our JAX models, they are re-fitted.
220263
221- model1 = self .outcome_model_template .__class__ () # Create a new instance of the same type
264+ model1 = (
265+ self .outcome_model_template .__class__ ()
266+ ) # Create a new instance of the same type
222267 if X_treated .shape [0 ] > 0 :
223268 model1 .fit (X_treated , y_treated )
224269 mu1_X = model1 .predict (X_jax )
225270 self .params ["mu1_params" ] = model1 .params
226- else : # Should not happen in typical scenarios
271+ else : # Should not happen in typical scenarios
227272 mu1_X = jnp .zeros (n_samples )
228273 self .params ["mu1_params" ] = None
229274
230- model0 = self .outcome_model_template .__class__ () # Create a new instance
275+ model0 = self .outcome_model_template .__class__ () # Create a new instance
231276 if X_control .shape [0 ] > 0 :
232277 model0 .fit (X_control , y_control )
233278 mu0_X = model0 .predict (X_jax )
234279 self .params ["mu0_params" ] = model0 .params
235- else : # Should not happen
280+ else : # Should not happen
236281 mu0_X = jnp .zeros (n_samples )
237282 self .params ["mu0_params" ] = None
238283
239-
240284 # 3. Calculate AIPW estimator components
241285 # ψ_i = μ_1(X_i) - μ_0(X_i) + T_i/e(X_i) * (Y_i - μ_1(X_i)) - (1-T_i)/(1-e(X_i)) * (Y_i - μ_0(X_i))
242286
243287 term1 = mu1_X - mu0_X
244- term2 = (T_jax / propensity_scores ) * (y_jax - mu1_X )
245- term3 = ((1 - T_jax ) / (1 - propensity_scores )) * (y_jax - mu0_X )
288+ term2 = (W_jax / propensity_scores ) * (y_jax - mu1_X )
289+ term3 = ((1 - W_jax ) / (1 - propensity_scores )) * (y_jax - mu0_X )
246290
247291 psi_i = term1 + term2 - term3
248292
@@ -255,6 +299,12 @@ def summary(self) -> None:
255299 super ().summary ()
256300 if self .params and "ate" in self .params and self .params ["ate" ] is not None :
257301 print (f" Estimated ATE (AIPW): { self .params ['ate' ]:.4f} " )
258- if self .params and "propensity_scores" in self .params and self .params ["propensity_scores" ] is not None :
259- print (f" Propensity scores min: { jnp .min (self .params ['propensity_scores' ]):.4f} , max: { jnp .max (self .params ['propensity_scores' ]):.4f} " )
302+ if (
303+ self .params
304+ and "propensity_scores" in self .params
305+ and self .params ["propensity_scores" ] is not None
306+ ):
307+ print (
308+ f" Propensity scores min: { jnp .min (self .params ['propensity_scores' ]):.4f} , max: { jnp .max (self .params ['propensity_scores' ]):.4f} "
309+ )
260310 # Could add info about outcome model parameters if desired
0 commit comments