-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpopulation_std.py
More file actions
118 lines (80 loc) · 3.91 KB
/
population_std.py
File metadata and controls
118 lines (80 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import numpy as np
import itertools
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils import data
from scipy.stats import binned_statistic
from scipy.signal import savgol_filter
import argparse
from pathlib import Path
import sys
# Make the repository root (two levels up from this file) importable so the
# project-local DataLoader package can be found.
file = Path(__file__).resolve()
package_root_directory = file.parents[1]
sys.path.append(str(package_root_directory))
from DataLoader.dataset import Dataset

# Command-line interface: choose which dataset variant to process.
parser = argparse.ArgumentParser('Pop_std')
parser.add_argument(
    '--dataset',
    type=str,
    choices=['elsa', 'sample'],
    default='elsa',
    help="the dataset that will be used to train the model; either 'elsa' or 'sample'",
)
args = parser.parse_args()

# File-name suffix selecting the sample data files; empty for the full dataset.
postfix = '' if args.dataset != 'sample' else '_sample'
# Directory containing this script; data paths below are resolved relative to it.
dir = os.path.dirname(os.path.realpath(__file__))
def nan_helper(y):
    """Locate NaNs in an array, for use with linear interpolation.

    Args:
        y: array-like of floats (used here on 1-D profiles; the returned
            index function only reports positions along the first axis).

    Returns:
        A pair ``(is_nan, to_indices)``. ``is_nan`` is the boolean mask
        ``np.isnan(y)``. ``to_indices`` converts a boolean mask into the integer
        positions where it is True. That is the coordinate form ``np.interp``
        expects, e.g. ``np.interp(to_indices(is_nan), to_indices(~is_nan), y[~is_nan])``.
    """
    is_nan = np.isnan(y)

    def to_indices(flags):
        return flags.nonzero()[0]

    return is_nan, to_indices
# Run configuration.
device = 'cpu'
N = 29  # number of deficit columns per record
dt = 0.5

# Load the whole training split; with batch_size equal to the dataset length the
# DataLoader yields the full (shuffled) set as one batch.
train_name = f'{dir}/Data/train{postfix}.csv'
training_set = Dataset(train_name, N, pop=True)
num_train = len(training_set)
training_generator = data.DataLoader(
    training_set,
    batch_size=num_train,
    shuffle=True,
    drop_last=True,
)
mean_T = training_set.mean_T
std_T = training_set.std_T

# Age-bin edges 40, 43, ..., 103 and the midpoint of each bin.
age_bins = np.arange(40, 105, 3)
bin_centers = age_bins[1:] - np.diff(age_bins) / 2.0
n_age_bins = bin_centers.shape[0]

# Result arrays indexed [sex, age_bin, column]:
#   avg / avg_smooth: column 0 = bin center, columns 1..N = per-deficit std.
#   avg_env / avg_env_smooth: column 0 = height, column 1 = BMI.
avg = np.zeros((2, n_age_bins, N + 1))
avg_smooth = np.zeros_like(avg)
avg_env = np.zeros((2, n_age_bins, 2))
avg_env_smooth = np.zeros_like(avg_env)
def _binned_std_filled(ages, values, valid):
    """Return the per-age-bin standard deviation of ``values[valid]`` with gaps filled.

    The std is computed over the module-level ``age_bins`` with
    ``scipy.stats.binned_statistic``. It is trimmed to bins ``[3:-4]`` (the window
    this script fills directly; outer bins are padded later). Bins in that window
    that received no samples come back as NaN. They are filled by linear
    interpolation between the neighbouring populated bins.

    Args:
        ages: 1-D array of ages/times, one per sample.
        values: 1-D array of measurements aligned with ``ages``.
        valid: boolean mask selecting which samples to include.

    Returns:
        1-D float array covering bins ``[3:-4]`` of ``age_bins``.

    Raises:
        ValueError: from ``np.interp`` if every bin in the window is empty.
    """
    stats = binned_statistic(ages[valid], values[valid], bins=age_bins,
                             statistic=np.std)[0][3:-4]
    nans, index = nan_helper(stats)
    stats[nans] = np.interp(index(nans), index(~nans), stats[~nans])
    return stats


# The DataLoader is built with batch_size=num_train (the dataset length), so this
# body runs once over the full training set.
for (batch_data, batch_times, batch_mask, batch_survival_mask, _, _,
     batch_censored, _, batch_env, batch_med, batch_weights) in training_generator:
    times = batch_times.numpy()
    # Named `values` rather than `data` so the `torch.utils.data` module imported
    # at the top of the file is not shadowed.
    values = batch_data.numpy()
    mask = batch_mask.numpy()
    env = batch_env.numpy()
    env_times = times[:, 0]  # first time column of each row
    num_env = 29 + 19 - N - 5  # total variables - deficits - medications
    sex_index = num_env - 1
    bmi_index = num_env - 3
    height_index = num_env - 4
    for sex in [0, 1]:
        selected = env[:, sex_index] == sex
        # Flatten (row, time) into one sample axis for the selected rows.
        curr_times = times[selected].reshape(-1)
        curr_data = values[selected].reshape(-1, N)
        curr_mask = mask[selected].reshape(-1, N)
        for evid, ev in enumerate([height_index, bmi_index]):
            ev_values = env[selected, ev]
            # Entries <= -100 are excluded (presumably a missing-value sentinel; confirm).
            stats = _binned_std_filled(env_times[selected], ev_values, ev_values > -100)
            avg_env[sex, 3:-4, evid] = stats
            # Smooth only after NaN gaps are interpolated; smoothing raw NaNs is meaningless.
            avg_env_smooth[sex, 3:-4, evid] = savgol_filter(stats, 9, 3)
        for n in range(N):
            stats = _binned_std_filled(curr_times, curr_data[:, n], curr_mask[:, n] > 0)
            avg[sex, 3:-4, 1 + n] = stats
            avg_smooth[sex, 3:-4, 1 + n] = savgol_filter(stats, 9, 3)
# Bins outside the [3:-4] window were never computed. Fill them by copying the
# nearest computed bin outward (index 3 to the left edge, index -5 to the right
# edge), broadcasting over both sexes at once.
for arr in (avg, avg_smooth, avg_env, avg_env_smooth):
    arr[:, :3] = arr[:, 3:4]
    arr[:, -4:] = arr[:, -5:-4]

# Column 0 of the deficit arrays stores the age-bin midpoints.
avg[:, :, 0] = bin_centers
avg_smooth[:, :, 0] = bin_centers
# Persist only the smoothed profiles; the unsmoothed arrays are not written out.
std_path = f'{dir}/Data/Population_std{postfix}.npy'
env_std_path = f'{dir}/Data/Population_std_env{postfix}.npy'
np.save(std_path, avg_smooth)
np.save(env_std_path, avg_env_smooth)