forked from AMLab-Amsterdam/CEVAE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
33 lines (29 loc) · 1.78 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
from sklearn.model_selection import train_test_split
class IHDP(object):
    """Loader for the IHDP data stored as one CSV file per replication.

    Replication ``k`` (1-based) is read from ``<path_data>/ihdp_npci_<k>.csv``
    and its columns are split as follows:

    * column 0   -> ``t``
    * column 1   -> ``y``
    * column 2   -> ``y_cf``
    * column 3   -> ``mu_0``
    * column 4   -> ``mu_1``
    * columns 5+ -> covariate matrix ``x``
    """

    # Covariate column (index into ``x``) whose raw values are in {1, 2}.
    # It is listed in ``binfeats``, so both iteration APIs shift it to {0, 1}.
    _ONE_BASED_BINFEAT = 13

    def __init__(self, path_data="datasets/IHDP/csv", replications=10):
        """
        :param path_data: directory containing the ``ihdp_npci_<k>.csv`` files.
        :param replications: number of replication files to read (k = 1..replications).
        """
        self.path_data = path_data
        self.replications = replications
        # which covariate columns of x are binary
        self.binfeats = list(range(6, 25))
        # which covariate columns of x are continuous
        self.contfeats = [i for i in range(25) if i not in self.binfeats]

    def _load_replication(self, i):
        """Load replication ``i`` (0-based) and split it into column arrays.

        :return: ``(x, t, y, y_cf, mu_0, mu_1)``. ``t`` is 1-D; ``y``, ``y_cf``,
            ``mu_0`` and ``mu_1`` have shape (n, 1). Covariate
            ``_ONE_BASED_BINFEAT`` is shifted from {1, 2} to {0, 1}.
        """
        path = self.path_data + '/ihdp_npci_' + str(i + 1) + '.csv'
        # ndmin=2 keeps column indexing valid even for a single-row file.
        data = np.loadtxt(path, delimiter=',', ndmin=2)
        t = data[:, 0]
        y, y_cf = data[:, 1:2], data[:, 2:3]
        mu_0, mu_1 = data[:, 3:4], data[:, 4:5]
        x = data[:, 5:]
        # this binary feature is stored as {1, 2}; map it to {0, 1}
        x[:, self._ONE_BASED_BINFEAT] -= 1
        return x, t, y, y_cf, mu_0, mu_1

    @staticmethod
    def _subset(x, t, y, y_cf, mu_0, mu_1, idx):
        """Select rows ``idx`` and group them as ``(x, t, y), (y_cf, mu_0, mu_1)``."""
        return (x[idx], t[idx], y[idx]), (y_cf[idx], mu_0[idx], mu_1[idx])

    def __iter__(self):
        """Yield ``(x, t, y), (y_cf, mu_0, mu_1)`` for every replication.

        Here ``t`` is a 1-D array; the other outcome arrays have shape (n, 1).
        """
        for i in range(self.replications):
            x, t, y, y_cf, mu_0, mu_1 = self._load_replication(i)
            yield (x, t, y), (y_cf, mu_0, mu_1)

    def get_train_valid_test(self, test_size=0.1, valid_size=0.3, random_state=1):
        """Yield a train/valid/test split for every replication.

        Rows are first split into ``test_size`` test and the rest; the rest is
        then split again with ``valid_size`` going to validation.

        :param test_size: fraction of rows held out for test (default 0.1).
        :param valid_size: fraction of the non-test rows used for validation
            (default 0.3).
        :param random_state: seed passed to both ``train_test_split`` calls.
        :return: generator of ``(train, valid, test, contfeats, binfeats)``;
            each split is ``(x, t, y), (y_cf, mu_0, mu_1)`` with ``t`` of shape (n, 1).
        """
        for i in range(self.replications):
            x, t, y, y_cf, mu_0, mu_1 = self._load_replication(i)
            t = t[:, np.newaxis]
            arrays = (x, t, y, y_cf, mu_0, mu_1)
            idxtrain, ite = train_test_split(np.arange(x.shape[0]), test_size=test_size,
                                             random_state=random_state)
            itr, iva = train_test_split(idxtrain, test_size=valid_size,
                                        random_state=random_state)
            train = self._subset(*arrays, idx=itr)
            valid = self._subset(*arrays, idx=iva)
            test = self._subset(*arrays, idx=ite)
            yield train, valid, test, self.contfeats, self.binfeats