20
20
import arviz as az
21
21
import numpy as np
22
22
import pandas as pd
23
+ import xarray as xr
23
24
from matplotlib import pyplot as plt
24
25
from patsy import build_design_matrices , dmatrices
25
26
from sklearn .base import RegressorMixin
@@ -84,6 +85,8 @@ def __init__(
84
85
** kwargs ,
85
86
) -> None :
86
87
super ().__init__ (model = model )
88
+ # rename the index to "obs_ind"
89
+ data .index .name = "obs_ind"
87
90
self .input_validation (data , treatment_time )
88
91
self .treatment_time = treatment_time
89
92
# set experiment type - usually done in subclasses
@@ -107,6 +110,33 @@ def __init__(
107
110
)
108
111
self .post_X = np .asarray (new_x )
109
112
self .post_y = np .asarray (new_y )
113
+ # turn into xarray.DataArray's
114
+ self .pre_X = xr .DataArray (
115
+ self .pre_X ,
116
+ dims = ["obs_ind" , "coeffs" ],
117
+ coords = {
118
+ "obs_ind" : self .datapre .index ,
119
+ "coeffs" : self .labels ,
120
+ },
121
+ )
122
+ self .pre_y = xr .DataArray (
123
+ self .pre_y [:, 0 ],
124
+ dims = ["obs_ind" ],
125
+ coords = {"obs_ind" : self .datapre .index },
126
+ )
127
+ self .post_X = xr .DataArray (
128
+ self .post_X ,
129
+ dims = ["obs_ind" , "coeffs" ],
130
+ coords = {
131
+ "obs_ind" : self .datapost .index ,
132
+ "coeffs" : self .labels ,
133
+ },
134
+ )
135
+ self .post_y = xr .DataArray (
136
+ self .post_y [:, 0 ],
137
+ dims = ["obs_ind" ],
138
+ coords = {"obs_ind" : self .datapost .index },
139
+ )
110
140
111
141
# fit the model to the observed (pre-intervention) data
112
142
if isinstance (self .model , PyMCModel ):
@@ -125,10 +155,8 @@ def __init__(
125
155
126
156
# calculate the counterfactual
127
157
self .post_pred = self .model .predict (X = self .post_X )
128
- self .pre_impact = self .model .calculate_impact (self .pre_y [:, 0 ], self .pre_pred )
129
- self .post_impact = self .model .calculate_impact (
130
- self .post_y [:, 0 ], self .post_pred
131
- )
158
+ self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
159
+ self .post_impact = self .model .calculate_impact (self .post_y , self .post_pred )
132
160
self .post_impact_cumulative = self .model .calculate_cumulative_impact (
133
161
self .post_impact
134
162
)
0 commit comments