diff --git a/README.md b/README.md
index 115f9c7..1a27a7f 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,163 @@
-# SepsisLab
\ No newline at end of file
+# SepsisLab: Early Sepsis Prediction with Uncertainty Quantification and Active Sensing
+
+This repository contains the official PyTorch implementation of the following paper:
+
+> **SepsisLab: Early Sepsis Prediction with Uncertainty Quantification and Active Sensing (KDD 2024)**
+> Changchang Yin, Pin-Yu Chen, Bingsheng Yao, Dakuo Wang, Jeffrey Caterino, Ping Zhang
+> [paper]()
+>
+> **Abstract:** *Sepsis is the leading cause of in-hospital mortality in the USA. Early sepsis onset prediction and diagnosis could significantly improve the survival of sepsis patients. Existing predictive models are usually trained on high-quality data with little missing information, while missing values widely exist in real-world clinical scenarios (especially in the first hours of admissions to the hospital), which causes a significant decrease in accuracy and an increase in uncertainty for the predictive models. The common method to handle missing values is imputation, which replaces the unavailable variables with estimates from the observed data. The uncertainty of imputation results can be propagated to the sepsis prediction outputs, which has not been studied in existing works on either sepsis prediction or uncertainty quantification. In this study, we first define such propagated uncertainty as the variance of prediction output and then introduce uncertainty propagation methods to quantify the propagated uncertainty. Moreover, for the potential high-risk patients with low confidence due to limited observations, we propose a robust active sensing algorithm to increase confidence by actively recommending clinicians to observe the most informative variables. We validate the proposed models in both publicly available data (i.e., MIMIC-III and AmsterdamUMCdb) and proprietary data in The Ohio State University Wexner Medical Center (OSUWMC). The experimental results show that the propagated uncertainty is dominant at the beginning of admissions to hospitals and the proposed algorithm outperforms state-of-the-art active sensing methods. Finally, we implement a SepsisLab system for early sepsis prediction and active sensing based on our pre-trained models. Clinicians and potential sepsis patients can benefit from the system in early prediction and diagnosis of sepsis.*
+
+# Framework
+
+SepsisLab imputes missing values, makes sepsis predictions, computes the uncertainty propagated from missing values to the prediction outputs, and uses active sensing to improve the sepsis prediction results.
+
+
+Model framework. (A) The imputation model takes observed variables and corresponding timestamps as input,
+and generates the distribution of missing values. (B) Sepsis prediction model produces the patients’ sepsis onset risks with
+uncertainty based on the imputed data. (C) shows the uncertainty quantification method with Monte-Carlo sampling. (D)
+displays the uncertainty propagation method that can estimate propagated uncertainty by multiplying models’ gradient over
+imputed variables and the imputation uncertainty.
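+
+As a rough illustration of (C) and (D), the propagated uncertainty can be approximated to first order by weighting each imputed variable's variance with the squared gradient of the risk score with respect to that variable. The sketch below is a minimal, hypothetical PyTorch example (the `model`, `x_imputed`, and `imputation_std` names are illustrative and not the exact code in `code/prediction`):
+
+```
+import torch
+
+def propagated_uncertainty(model, x_imputed, imputation_std):
+    # x_imputed: imputed feature vector for one patient; imputation_std: per-variable std
+    x = x_imputed.clone().detach().requires_grad_(True)
+    risk = model(x)                  # scalar sepsis onset risk
+    risk.backward()                  # gradients of the risk w.r.t. every imputed variable
+    # First-order propagation: Var(risk) ~ sum_i (d risk / d x_i)^2 * Var(x_i)
+    variance = (x.grad ** 2 * imputation_std ** 2).sum()
+    return variance.sqrt()           # propagated uncertainty (std) of the risk score
+```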
+
+
+Settings of sepsis onset prediction.
+
+
+
+User Interface of Our SepsisLab System. (A) Patient list with sepsis risk prediction score. (B) The patient's demographics and the dashboard of the patient's historical observations. (C) Predicted sepsis risk score with uncertainty range and recommended lab test items to observe.
+
+
+
+The Interactive Lab Test Recommendation Module in SepsisLab System.
+
+# Files Directory
+ SepsisLab
+ |
+ |--code
+ | |
+ | |--imputation * SepsisLab imputes the missing values and generate uncertainty of missing values
+ | |
+ | |--prediction * SepsisLab predicts sepsis risks and use active sensing to reduce propagated uncertainty
+ |
+ |--file * The preprocessing codes will write some files here.
+ |
+ |--data * Put the downloaded datasets here.
+ | |
+ | |--OSUWMC
+ | | |
+ | | |--train_groundtruth
+ | | |
+ | | |--sepsis_labels
+ | |
+ | |--MIMIC
+ | | |
+ | | |--train_groundtruth
+ | | |
+ | | |--sepsis_labels
+ | |
+ | |--AmsterdamUMCdb
+ | | |
+ | | |--train_groundtruth
+ | | |
+ | | |--sepsis_labels
+ |
+ |
+ |--result * The results of imputation, sepsis prediction and active sensing will be saved here.
+ | |--OSUWMC
+ | |
+ | |--MIMIC
+ | |
+ | |--AmsterdamUMCdb
+
+# Environment
+Ubuntu 16.04, Python 3.8
+
+
+# Data preprocessing
+
+
+## MIMIC-III data preprocessing
+1. Download the [MIMIC-III](https://mimic.physionet.org) dataset and put the data in SepsisLab/data/MIMIC/initial\_mimiciii/.
+
+2. Generate the pivoted files (pivoted\_lab.csv, pivoted\_vital.csv, pivoted\_sofa.csv) according to [MIT-LCP/mimic-code](https://github.com/MIT-LCP/mimic-code/blob/master/concepts/pivot/), and put them in SepsisLab/data/MIMIC/initial\_mimiciii/.
+
+- SQL for pivoted file generation can be found [here](https://github.com/yinchangchang/TAME/blob/master/code/preprocessing/pivoted_file_generation.md).
+
+3. Preprocess MIMIC-III data.
+```
+cd code/preprocessing
+python preprocess_mimic_data.py --dataset MIMIC
+python generate_sepsis_variables.py --dataset MIMIC
+python generate_value_distribution.py --dataset MIMIC
+```
+
+# Imputation
+
+1. Train imputation model.
+```
+cd code/imputation
+python main.py --dataset MIMIC
+```
+
+2. Generate the imputation results.
+```
+cd code/imputation
+python main.py --dataset MIMIC --phase test --resume ../../data/MIMIC/models/best.ckpt
+```
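+
+The test phase writes its outputs to `result/MIMIC/imputation_result/`: a `.npy` array of per-visit model outputs and a `.csv` file per admission in the original format with the missing entries filled in (see `code/imputation/main.py`). A small sketch for inspecting them, assuming the default result directory layout:
+
+```
+import numpy as np
+from glob import glob
+
+result_dir = '../../result/MIMIC/imputation_result'   # adjust to where result/ lives
+
+csv_files = sorted(glob(result_dir + '/*.csv'))
+if csv_files:
+    with open(csv_files[0]) as f:
+        lines = f.read().strip().split('\n')
+    print(csv_files[0], 'columns:', lines[0].split(',')[:5], 'rows:', len(lines) - 1)
+
+npy_files = sorted(glob(result_dir + '/*.npy'))
+if npy_files:
+    print(npy_files[0], 'output array shape:', np.load(npy_files[0]).shape)
+```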
+
+## Results of imputation
+
+The RMSE imputation results on the MIMIC dataset.
+```
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+Model | aniongap |bicarbonate | creatinine | chloride | glucose | hemoglobin | lactate | platelet | ptt | inr | pt | sodium | bun | wbc | spo2 | C-reactive | heartrate | hematocrit | sysbp | tempc | diasbp | gcs | resprate | bands | meanbp | Magnesium |urineoutput | Mean
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+Mean | 0.29 | 0.24 | 0.25 | 0.22 | 0.30 | 0.27 | 0.42 | 0.27 | 0.46 | 0.32 | 0.41 | 0.24 | 0.23 | 0.26 | 0.32 | 1.37 | 0.40 | 0.31 | 0.31 | 0.38 | 0.32 | 0.37 | 0.31 | 0.86 | 0.22 | 0.42 | 0.34 | 0.37
+KNN | 0.28 | 0.22 | 0.22 | 0.22 | 0.30 | 0.25 | 0.44 | 0.26 | 0.38 | 0.31 | 0.29 | 0.24 | 0.22 | 0.25 | 0.30 | 1.31 | 0.37 | 0.28 | 0.24 | 0.37 | 0.24 | 0.38 | 0.27 | 0.80 | 0.17 | 0.41 | 0.33 | 0.34
+3DMICe | 0.22 | 0.19 | 0.22 | 0.18 | 0.27 | 0.18 | 0.42 | 0.25 | 0.40 | 0.25 | 0.29 | 0.20 | 0.22 | 0.25 | 0.27 | 1.20 | 0.34 | 0.28 | 0.24 | 0.36 | 0.20 | 0.33 | 0.27 | 0.79 | 0.15 | 0.38 | 0.30 | 0.32
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+BRNN | 0.15 | 0.17 | 0.20 | 0.13 | 0.29 | 0.12 | 0.40 | 0.20 | 0.41 | 0.16 | 0.24 | 0.18 | 0.17 | 0.26 | 0.25 | 1.28 | 0.31 | 0.25 | 0.18 | 0.26 | 0.17 | 0.23 | 0.24 | 0.96 | 0.13 | 0.36 | 0.27 | 0.30
+CATSI | 0.12 | 0.12 | 0.22 | 0.13 | 0.29 | 0.14 | 0.41 | 0.22 | 0.42 | 0.20 | 0.25 | 0.18 | 0.20 | 0.23 | 0.25 | 1.13 | 0.34 | 0.25 | 0.18 | 0.24 | 0.16 | 0.22 | 0.24 | 0.85 | 0.13 | 0.33 | 0.27 | 0.29
+DETROIT | 0.11 | 0.09 | 0.28 | 0.09 | 0.27 | 0.13 | 0.38 | 0.22 | 0.46 | 0.17 | 0.24 | 0.10 | 0.17 | 0.22 | 0.26 | 1.10 | 0.31 | 0.26 | 0.18 | 0.24 | 0.16 | 0.23 | 0.24 | 0.78 | 0.13 | 0.33 | 0.25 | 0.27
+BRITS | 0.12 | 0.08 | 0.23 | 0.12 | 0.27 | 0.12 | 0.39 | 0.20 | 0.41 | 0.18 | 0.24 | 0.16 | 0.20 | 0.20 | 0.26 | 1.22 | 0.32 | 0.23 | 0.19 | 0.23 | 0.15 | 0.20 | 0.24 | 0.84 | 0.13 | 0.34 | 0.26 | 0.28
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+TAME-T | 0.13 | 0.11 | 0.24 | 0.10 | 0.25 | 0.11 | 0.34 | 0.19 | 0.36 | 0.21 | 0.24 | 0.11 | 0.17 | 0.20 | 0.24 | 1.29 | 0.24 | 0.17 | 0.15 | 0.25 | 0.13 | 0.16 | 0.23 | 0.71 | 0.12 | 0.30 | 0.24 | 0.26
+TAME-V | 0.16 | 0.13 | 0.23 | 0.12 | 0.26 | 0.11 | 0.36 | 0.20 | 0.38 | 0.19 | 0.22 | 0.14 | 0.17 | 0.20 | 0.23 | 1.29 | 0.23 | 0.18 | 0.13 | 0.21 | 0.12 | 0.15 | 0.22 | 0.70 | 0.11 | 0.31 | 0.24 | 0.26
+TAME-M | 0.13 | 0.11 | 0.24 | 0.10 | 0.25 | 0.11 | 0.34 | 0.19 | 0.36 | 0.21 | 0.24 | 0.11 | 0.17 | 0.23 | 0.24 | 1.33 | 0.24 | 0.20 | 0.14 | 0.25 | 0.14 | 0.16 | 0.24 | 0.73 | 0.14 | 0.32 | 0.25 | 0.27
+TAME | 0.11 | 0.09 | 0.19 | 0.08 | 0.26 | 0.09 | 0.35 | 0.18 | 0.38 | 0.15 | 0.20 | 0.10 | 0.14 | 0.21 | 0.22 | 1.16 | 0.23 | 0.19 | 0.13 | 0.24 | 0.12 | 0.16 | 0.20 | 0.73 | 0.12 | 0.31 | 0.23 | 0.25
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+```
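+
+The values above are normalized RMSEs: following `code/imputation/function.py`, the error on the held-out (artificially masked) entries of each variable is divided by that variable's observed min-max range within each admission and then averaged. A simplified per-variable sketch of the metric:
+
+```
+import numpy as np
+
+def normalized_rmse(pred, label, mask):
+    # pred/label/mask: 1-D arrays for one variable of one admission;
+    # mask == 1 marks held-out values, mask == 0 observed values, mask == -1 truly missing
+    observed = label[mask >= 0]
+    if len(observed) == 0 or observed.max() == observed.min():
+        return None                      # skip empty or constant variables
+    scale = observed.max() - observed.min()
+    held_out = mask == 1
+    if held_out.sum() == 0:
+        return None
+    err = (pred[held_out] - label[held_out]) / scale
+    return float(np.sqrt((err ** 2).mean()))
+```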
+
+
+# Sepsis prediction
+
+Compute sepsis onset predictions.
+```
+cd code/prediction
+python main.py
+```
+
+# Active Sensing
+Use active sensing to reduce propagated uncertainty.
+```
+cd code/prediction
+python active_sensing.py
+```
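+
+At a high level, `active_sensing.py` recommends the lab variables whose observation would shrink the propagated uncertainty the most: when a patient's total propagated uncertainty exceeds `args.uncertainty_threshold`, the unobserved variables with the largest gradient-times-imputation-uncertainty contributions are suggested, their values are filled in, and the risk is re-estimated. A hedged, stand-alone sketch of that selection rule (names are illustrative):
+
+```
+import numpy as np
+
+def recommend_lab_tests(grad, imputation_std, observed, threshold, top_k=3):
+    # grad, imputation_std, observed: 1-D arrays over variables for one patient;
+    # observed == 1 means the variable has already been measured.
+    contribution = (grad ** 2) * (imputation_std ** 2)   # per-variable uncertainty share
+    contribution[observed == 1] = 0.0                    # only unobserved variables qualify
+    if contribution.sum() ** 0.5 <= threshold:
+        return []                                        # already confident enough
+    return list(np.argsort(-contribution)[:top_k])       # most informative variables first
+```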
+
+
+
+# Results
+## Uncertainty over different active sensing ratio
+
+
+
+
+## Sepsis onset prediction performance with different uncertainty
+
+
+
+
diff --git a/code/imputation/data_loader.py b/code/imputation/data_loader.py
new file mode 100644
index 0000000..01194c1
--- /dev/null
+++ b/code/imputation/data_loader.py
@@ -0,0 +1,616 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+
+import json
+import collections
+import os
+import random
+import time
+import warnings
+from copy import deepcopy
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+import sys
+sys.path.append('../tools')
+import py_op
+
+def find_index(v, vs, i=0, j=-1):
+ if j == -1:
+ j = len(vs) - 1
+
+ if v > vs[j]:
+ return j + 1
+ elif v < vs[i]:
+ return i
+ elif j - i == 1:
+ return j
+
+ k = int((i + j)/2)
+ if v <= vs[k]:
+ return find_index(v, vs, i, k)
+ else:
+ return find_index(v, vs, k, j)
+
+def add_time_gap(idata, odata, n = 10):
+ '''
+ delete lines with only meanbp_min, and urineoutput
+ '''
+ new_idata = []
+ new_odata = []
+ for iline,oline in zip(idata, odata):
+ vs = []
+ for v in oline.strip().split(','):
+ if v in ['', 'NA']:
+ vs.append(0)
+ else:
+ vs.append(1)
+ vs[0] = 0
+ vs[6] = 0
+ vs[8] = 0
+ if np.sum(vs) > 0:
+ new_idata.append(iline)
+ new_odata.append(oline)
+ return new_idata, new_odata
+
+
+class DataBowl(Dataset):
+ def __init__(self, args, files, phase='train'):
+ assert (phase == 'train' or phase == 'valid' or phase == 'test')
+ self.args = args
+ self.phase = phase
+ self.files = files
+
+ self.feature_mm_dict = py_op.myreadjson(os.path.join(args.file_dir, args.dataset + '_feature_mm_dict.json'))
+ self.feature_value_dict = py_op.myreadjson(os.path.join(args.file_dir, args.dataset + '_feature_value_dict_{:d}.json'.format(args.split_num)))
+ self.n_dd = 40
+ if args.dataset in ['MIMIC']:
+ self.ehr_list = py_op.myreadjson(os.path.join(args.data_dir, args.dataset, 'ehr_list.json' ))
+ self.ehr_id = { ehr: i+1 for i,ehr in enumerate(self.ehr_list) }
+ self.args.n_ehr = len(self.ehr_id) + 1
+
+ def map_input(self, value, feat_list, feat_index):
+
+ # for each feature (index), there are 1 embedding vectors for NA, split_num=100 embedding vectors for different values
+ index_start = (feat_index + 1)* (1 + self.args.split_num) + 1
+
+ if value in ['NA', '']:
+ if self.args.value_embedding == 'no':
+ return 0
+ return index_start - 1
+ else:
+ # print('""' + value + '""')
+ value = float(value)
+ if self.args.value_embedding == 'use_value':
+ minv, maxv = self.feature_mm_dict[feat_list[feat_index]]
+ v = (value - minv) / (maxv - minv + 10e-10)
+ # print(v, minv, maxv)
+ assert v >= 0
+ # map the value to its embedding index
+ v = int(self.args.split_num * v) + index_start
+ return v
+ elif self.args.value_embedding == 'use_order':
+ vs = self.feature_value_dict[feat_list[feat_index]][1:-1]
+ v = find_index(value, vs) + index_start
+ return v
+ elif self.args.value_embedding == 'no':
+ minv, maxv = self.feature_mm_dict[feat_list[feat_index]]
+ # v = (value - minv) / (maxv - minv)
+ v = (value - minv) / maxv + 1
+ v = int(v * self.args.split_num) / float(self.args.split_num)
+ return v
+
+ def map_output(self, value, feat_list, feat_index):
+ if value in ['NA', '']:
+ return 0
+ else:
+ value = float(value)
+ minv, maxv = self.feature_mm_dict[feat_list[feat_index]]
+ if maxv <= minv:
+ print(feat_list[feat_index], minv, maxv)
+ assert maxv > minv
+ v = (value - minv) / (maxv - minv)
+ # v = (value - minv) / (maxv - minv)
+ v = max(0, min(v, 1))
+ return v
+
+ def get_pre_info(self, input_data, iline, feat_list):
+ iline = len(input_data) - iline - 1
+ input_data = input_data[::-1][: -1] # the first line is head
+ pre_input, pre_time = self.get_post_info(input_data, iline, feat_list)
+ return pre_input, pre_time
+
+
+ def get_post_info(self, input_data, iline, feat_list):
+ input_data = input_data[iline:]
+ # input_data = [s.split(',') for s in input_data]
+ post_input = [0]
+ post_time = [0]
+ for i in range(1, len(input_data[0])):
+ value = ''
+ t = 0
+ for j in range(1, len(input_data)):
+ if input_data[j][i] not in ['NA', '']:
+ value = input_data[j][i]
+ t = abs(int(input_data[j][0]) - int(input_data[0][0]))
+ break
+ post_input.append(self.map_input(value, feat_list, i))
+ post_time.append(t)
+ return post_input, post_time
+
+
+
+ def get_mm_item(self, idx):
+ input_file = self.files[idx]
+ output_file = input_file.replace('with_missing', 'groundtruth')
+
+
+
+
+ with open(output_file) as f:
+ output_data = f.read().strip().split('\n')
+ with open(input_file) as f:
+ input_data = f.read().strip().split('\n')
+
+
+
+
+
+ if self.args.random_missing and self.phase == 'train':
+ input_data = []
+ valid_list = []
+ for line in output_data:
+ values = line.strip().split(',')
+ input_data.append(values)
+ valid = []
+ for iv,v in enumerate(values):
+ if v.strip() not in ['', 'NA']:
+ valid.append(1)
+ else:
+ valid.append(0)
+ valid_list.append(valid)
+ valid_list = np.array(valid_list)
+ valid_list[0] = 0
+ for i in range(1, valid_list.shape[1]):
+ valid = valid_list[:, i]
+ indices = np.where(valid > 0)[0]
+ np.random.shuffle(indices)
+                if len(indices) > 2:
+ input_data[indices[0]][i] = 'NA'
+ input_data = [','.join(line) for line in input_data]
+
+
+
+ init_input_data = input_data
+
+ if self.args.use_ve == 0:
+ input_data = self.pre_filling(input_data)
+
+
+ assert len(input_data) == len(output_data)
+
+ mask_list, input_list, output_list = [], [], []
+ pre_input_list, pre_time_list = [], []
+ post_input_list, post_time_list = [], []
+ input_split = [x.strip().split(',') for x in init_input_data]
+
+ for iline in range(len(input_data)):
+ inp = input_data[iline].strip()
+ oup = output_data[iline].strip()
+ init_inp = init_input_data[iline].strip()
+
+ if iline == 0:
+ feat_list = inp.split(',')
+ else:
+ in_vs = inp.split(',')
+ ou_vs = oup.split(',')
+ init_vs = init_inp.split(',')
+ ctime = int(inp.split(',')[0])
+ mask, input, output = [], [], []
+ rd = np.random.random(len(in_vs))
+ for i, (iv, ov, ir, init_iv) in enumerate(zip(in_vs, ou_vs, rd, init_vs)):
+ if ir < 0.7:
+ # iv = 'NA'
+ pass
+
+
+ if init_iv not in ['NA', '']:
+ mask.append(0)
+ elif ov not in ['NA', '']:
+ # print('err')
+ mask.append(1)
+ else:
+ mask.append(-1)
+ if self.args.use_ve:
+ input.append(self.map_input(iv, feat_list, i))
+ else:
+ input.append(self.map_output(iv, feat_list, i))
+ output.append(self.map_output(ov, feat_list, i))
+ mask_list.append(mask)
+ input_list.append(input)
+ output_list.append(output)
+ # pre and post info
+ pre_input, pre_time = self.get_pre_info(input_split, iline, feat_list)
+ pre_input_list.append(pre_input)
+ pre_time_list.append(pre_time)
+ post_input, post_time = self.get_post_info(input_split, iline, feat_list)
+ post_input_list.append(post_input)
+ post_time_list.append(post_time)
+
+ if len(mask_list) < self.args.n_visit:
+ for _ in range(self.args.n_visit - len(mask_list)):
+ # pad empty visit
+ mask = [-1 for _ in range(self.args.output_size + 1)]
+ vs = [0 for _ in range(self.args.output_size + 1)]
+ mask_list.append(mask)
+ input_list.append(vs)
+ output_list.append(vs)
+ pre_input_list.append(vs)
+ pre_time_list.append(vs)
+ post_input_list.append(vs)
+ post_time_list.append(vs)
+ # print(mask_list)
+ else:
+ mask_list = mask_list[: self.args.n_visit]
+ input_list = input_list[: self.args.n_visit]
+ output_list = output_list[: self.args.n_visit]
+ pre_input_list = pre_input_list[: self.args.n_visit]
+ pre_time_list = pre_time_list[: self.args.n_visit]
+ post_input_list = post_input_list[: self.args.n_visit]
+ post_time_list = post_time_list[: self.args.n_visit]
+
+
+
+
+
+
+
+ # print(mask_list)
+ mask_list = np.array(mask_list, dtype=np.int64)
+ output_list = np.array(output_list, dtype=np.float32)
+ pre_time_list = np.array(pre_time_list, dtype=np.int64)
+ post_time_list = np.array(post_time_list, dtype=np.int64)
+ if self.args.value_embedding == 'no' or self.args.use_ve == 0:
+ input_list = np.array(input_list, dtype=np.float32)
+ pre_input_list = np.array(pre_input_list, dtype=np.float32)
+ post_input_list = np.array(post_input_list, dtype=np.float32)
+ else:
+ input_list = np.array(input_list, dtype=np.int64)
+ pre_input_list = np.array(pre_input_list, dtype=np.int64)
+ post_input_list = np.array(post_input_list, dtype=np.int64)
+
+ input_list = input_list[:, 1:]
+ output_list = output_list[:, 1:]
+ mask_list = mask_list[:, 1:]
+ pre_input_list = pre_input_list[:, 1:]
+ pre_time_list = pre_time_list[:, 1:]
+ post_input_list = post_input_list[:, 1:]
+ post_time_list = post_time_list[:, 1:]
+
+ time_list = [x[0] for x in input_split][1:]
+ max_time = int(time_list[min(self.args.n_visit, len(time_list) - 1)]) + 1
+
+ if self.args.dataset in ['MIMIC']:
+ ehr_dict = py_op.myreadjson(os.path.join(input_file.replace('with_missing', 'groundtruth').replace('.csv', '.json')))
+ else:
+ ehr_dict = dict()
+ icd_list = [self.ehr_id[e] for e in ehr_dict.get('icd_demo', { }) if e in self.ehr_id]
+ icd_list = set(icd_list)
+ icd_list = set()
+ drug_dict = ehr_dict.get('drug', { })
+ visit_dict = dict()
+ for i in range(- 250, max_time + 1):
+ visit_dict[i] = sorted(icd_list)
+ for k, drug_list in drug_dict.items():
+ stime, etime = k.split(' -- ')
+ id_list = list(set([self.ehr_id[e] for e in drug_list if e in self.ehr_id]))
+ if len(id_list):
+ for t in range(max(0, int(stime)), int(etime)):
+ if t in visit_dict:
+ visit_dict[t] = visit_dict[t] + id_list
+ for k,v in visit_dict.items():
+ v = list(set(v))
+ visit_dict[k] = v
+ # if self.n_dd < len(v):
+ # self.n_dd = max(self.n_dd, len(v))
+ # print(self.n_dd)
+ dd_list = np.zeros((len(input_list), self.n_dd), dtype=np.int64)
+ for i,t in enumerate(time_list[: self.args.n_visit]):
+ if int(t) not in visit_dict:
+ continue
+ id_list = visit_dict[int(t)]
+ if len(id_list):
+ id_list = np.array(id_list, dtype=np.int64)
+ if len(id_list) > self.n_dd:
+ np.random.shuffle(id_list)
+ dd_list[i] = id_list[- self.n_dd:]
+ else:
+ dd_list[i][:len(id_list)] = id_list
+
+ # assert pre_time_list.max() < 256
+ # assert post_time_list.max() < 256
+ assert pre_time_list.min() >= 0
+ assert post_time_list.min() >= 0
+ pre_time_list[pre_time_list>200] = 200
+ post_time_list[post_time_list>200] = 200
+ assert len(mask_list[0]) == self.args.output_size
+ assert len(mask_list[0]) == len(pre_input_list[0])
+
+ # print(input_list.shape)
+ return torch.from_numpy(input_list), torch.from_numpy(output_list), torch.from_numpy(mask_list), input_file,\
+ torch.from_numpy(pre_input_list), torch.from_numpy(pre_time_list), torch.from_numpy(post_input_list), \
+ torch.from_numpy(post_time_list), torch.from_numpy(dd_list)
+
+ def pre_filling(self, input_data):
+ valid_list = []
+ input_value = []
+ for line in input_data:
+ values = line.strip().split(',')
+ input_value.append(values)
+ valid = []
+ for iv,v in enumerate(values):
+ if v.strip() not in ['', 'NA']:
+ valid.append(1)
+ else:
+ valid.append(0)
+ valid_list.append(valid)
+ valid_list = np.array(valid_list)
+ valid_list[0] = 0
+
+ pre_filled_data = [x[:1] for x in input_value]
+ pre_filled_data[0] = input_value[0]
+ for i in range(1, valid_list.shape[1]):
+ valid = valid_list[:, i]
+ indices = np.where(valid > 0)[0]
+ if len(indices):
+ mean = np.mean([float(input_value[id][i]) for id in indices])
+ first = indices[0]
+ else:
+ mean = 0
+ first = 10000
+
+ if self.args.model == 'mean':
+ value_list = self.feature_value_dict[self.args.name_list[i - 1]]
+ mean = value_list[int(len(value_list)/2)]
+
+ last_value = mean
+ for i_line in range(1, valid_list.shape[0]):
+ if valid_list[i_line, i]:
+ pre_filled_data[i_line].append(input_value[i_line][i])
+ last_value = input_value[i_line][i]
+ else:
+ pre_filled_data[i_line].append(str(last_value))
+ new_input_data = [','.join(x) for x in pre_filled_data]
+ return new_input_data
+
+
+
+ def get_brnn_item(self, idx):
+ input_file = self.files[idx]
+ output_file = input_file.replace('with_missing', 'groundtruth')
+
+ with open(output_file) as f:
+ output_data = f.read().strip().split('\n')
+ with open(input_file) as f:
+ input_data = f.read().strip().split('\n')
+
+
+ valid_list = []
+ input_value = []
+ for line in input_data:
+ values = line.strip().split(',')
+ input_value.append(values)
+ valid = []
+ for iv,v in enumerate(values):
+ if v.strip() not in ['', 'NA']:
+ valid.append(1)
+ else:
+ valid.append(0)
+ valid_list.append(valid)
+ valid_list = np.array(valid_list)
+ valid_list[0] = 0
+
+ pre_filled_data = [x[:1] for x in input_value]
+ pre_filled_data[0] = input_value[0]
+ for i in range(1, valid_list.shape[1]):
+ valid = valid_list[:, i]
+ indices = np.where(valid > 0)[0]
+ if len(indices):
+ # mean.append(np.mean([float(input_value[id][i]) for id in indices]))
+ # first.append(indices[0])
+ mean = np.mean([float(input_value[id][i]) for id in indices])
+ first = indices[0]
+ else:
+ mean = 0
+ first = 10000
+
+ if self.args.model in ['mean', 'mice']:
+ value_list = self.feature_value_dict[self.args.name_list[i - 1]]
+ mean = value_list[int(len(value_list)/2)]
+
+ last_value = mean
+ for i_line in range(1, valid_list.shape[0]):
+ if valid_list[i_line, i]:
+ pre_filled_data[i_line].append(input_value[i_line][i])
+
+ last_value = input_value[i_line][i]
+ next_value = last_value
+ for j in range(i_line+1, len(valid_list)):
+ if valid_list[j, i]:
+ next_value = input_value[j][i]
+ break
+ # mean = (float(last_value) + float(next_value)) / 2
+ mean = last_value
+ else:
+ pre_filled_data[i_line].append(mean)
+
+
+ input_list, output_list, mask_list = [], [], []
+ feat_list = input_data[0].strip().split(',')
+ for i_line in range(1, self.args.n_visit+1):
+ input = []
+ output = []
+ mask = []
+ if i_line >= len(input_data):
+ i_line = len(input_data) - 1
+ if i_line == len(input_data):
+ output_line = ['' for _ in output_line]
+ else:
+ input_line = input_data[i_line].strip().split(',')
+ output_line = output_data[i_line].strip().split(',')
+
+ assert len(valid_list) == len(input_data)
+
+ for i_feat in range(1, len(feat_list)):
+ iv = input_line[i_feat]
+ ov = output_line[i_feat]
+
+
+ if ov in ['NA', '']:
+ mask.append(-1)
+ output.append(0)
+ input.append(self.map_output(pre_filled_data[i_line][i_feat], feat_list, i_feat))
+ elif valid_list[i_line, i_feat]:
+ mask.append(0)
+ output.append(self.map_output(ov, feat_list, i_feat))
+ input.append(output[-1])
+ else:
+ mask.append(1)
+ output.append(self.map_output(ov, feat_list, i_feat))
+ input.append(self.map_output(pre_filled_data[i_line][i_feat], feat_list, i_feat))
+ input_list.append(input)
+ output_list.append(output)
+ mask_list.append(mask)
+ input_list = np.array(input_list, dtype=np.float32)
+ output_list = np.array(output_list, dtype=np.float32)
+ mask_list = np.array(mask_list, dtype=np.int64)
+ return torch.from_numpy(input_list), torch.from_numpy(output_list), torch.from_numpy(mask_list), input_file
+
+ def get_detroit_item(self, idx):
+ input_file = self.files[idx]
+ output_file = input_file.replace('with_missing', 'groundtruth')
+
+ with open(output_file) as f:
+ output_data = f.read().strip().split('\n')
+ with open(input_file) as f:
+ input_data = f.read().strip().split('\n')
+
+ valid_list = []
+ input_value = []
+ for line in input_data:
+ values = line.strip().split(',')
+ input_value.append(values)
+ valid = []
+ for iv,v in enumerate(values):
+ if v.strip() not in ['', 'NA']:
+ valid.append(1)
+ else:
+ valid.append(0)
+ valid_list.append(valid)
+ valid_list = np.array(valid_list)
+ valid_list[0] = 0
+
+ pre_filled_data = [x[:1] for x in input_value]
+ pre_filled_data[0] = input_value[0]
+ for i in range(1, valid_list.shape[1]):
+ valid = valid_list[:, i]
+ indices = np.where(valid > 0)[0]
+ if len(indices):
+ first = indices[0]
+ else:
+ first = -1
+
+ for i_line in range(1, valid_list.shape[0]):
+ if valid_list[i_line, i]:
+ pre_filled_data[i_line].append(input_value[i_line][i])
+ last_value = input_value[i_line][i]
+ elif first >= 0:
+ for ni in indices:
+ if ni > i_line:
+ break
+ next_value = input_value[ni][i]
+ if i_line < first:
+ pre_filled_data[i_line].append(next_value)
+ else:
+ try:
+ pre_filled_data[i_line].append((float(last_value)+ float(next_value))/ 2.0)
+ except:
+ print(indices)
+ print(first)
+ print(i_line)
+ pre_filled_data[i_line].append((float(last_value)+ float(next_value))/ 2.0)
+ else:
+ pre_filled_data[i_line].append(0)
+
+
+ input_list, output_list, mask_list = [], [], []
+ feat_list = input_data[0].strip().split(',')
+ for i_line in range(1, self.args.n_visit+1):
+ input = []
+ output = []
+ mask = []
+ if i_line >= len(input_data):
+ i_line = len(input_data) - 1
+ if i_line == len(input_data):
+ output_line = ['' for _ in output_line]
+ else:
+ input_line = input_data[i_line].strip().split(',')
+ output_line = output_data[i_line].strip().split(',')
+
+ assert len(valid_list) == len(input_data)
+
+ for i_feat in range(1, len(feat_list)):
+ iv = input_line[i_feat]
+ ov = output_line[i_feat]
+
+
+ if ov in ['NA', '']:
+ mask.append(-1)
+ output.append(0)
+ input.append(self.map_output(pre_filled_data[i_line][i_feat], feat_list, i_feat))
+ elif valid_list[i_line, i_feat]:
+ mask.append(0)
+ output.append(self.map_output(ov, feat_list, i_feat))
+ input.append(output[-1])
+ else:
+ mask.append(1)
+ output.append(self.map_output(ov, feat_list, i_feat))
+ input.append(self.map_output(pre_filled_data[i_line][i_feat], feat_list, i_feat))
+ input_list.append(input)
+ output_list.append(output)
+ mask_list.append(mask)
+
+ # last and next 2 visits
+ input_list = input_list[:2] + input_list + input_list[-2:]
+ new_input = []
+ for i in range(2, 2 + self.args.n_visit):
+ input = []
+ for j in range(i-2, i+3):
+ input = input + input_list[j]
+ new_input.append(input)
+ input_list = new_input
+
+ input_list = np.array(input_list, dtype=np.float32)
+ output_list = np.array(output_list, dtype=np.float32)
+ mask_list = np.array(mask_list, dtype=np.int64)
+ return torch.from_numpy(input_list), torch.from_numpy(output_list), torch.from_numpy(mask_list), input_file
+
+
+
+ def __getitem__(self, idx):
+ if self.args.model in ['brnn', 'brits', 'mean', 'mice']:
+ return self.get_brnn_item(idx)
+ elif self.args.model == 'detroit':
+ return self.get_detroit_item(idx)
+ else:
+ return self.get_mm_item(idx)
+
+ def __len__(self):
+ return len(self.files)
+
+
+
+
diff --git a/code/imputation/function.py b/code/imputation/function.py
new file mode 100644
index 0000000..e5adb79
--- /dev/null
+++ b/code/imputation/function.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python
+# coding=utf-8
+import numpy as np
+
+import os
+import torch
+from sklearn import metrics
+
+
+def compute_nRMSE(pred, label, mask):
+ '''
+ same as 3dmice
+ '''
+ assert pred.shape == label.shape == mask.shape
+
+ missing_indices = mask==1
+ missing_pred = pred[missing_indices]
+ missing_label = label[missing_indices]
+ missing_rmse = np.sqrt(((missing_pred - missing_label) ** 2).mean())
+
+ init_indices = mask==0
+ init_pred = pred[init_indices]
+ init_label = label[init_indices]
+ init_rmse = np.sqrt(((init_pred - init_label) ** 2).mean())
+
+ metric_list = [missing_rmse, init_rmse]
+ for i in range(pred.shape[2]):
+ apred = pred[:,:,i]
+ alabel = label[:,:, i]
+ amask = mask[:,:, i]
+
+ mrmse, irmse = [], []
+ for ip in range(len(apred)):
+ ipred = apred[ip]
+ ilabel = alabel[ip]
+ imask = amask[ip]
+
+ x = ilabel[imask>=0]
+ if len(x) == 0:
+ continue
+
+ minv = ilabel[imask>=0].min()
+ maxv = ilabel[imask>=0].max()
+ if maxv == minv:
+ continue
+
+ init_indices = imask==0
+ init_pred = ipred[init_indices]
+ init_label = ilabel[init_indices]
+
+ missing_indices = imask==1
+ missing_pred = ipred[missing_indices]
+ missing_label = ilabel[missing_indices]
+
+ assert len(init_label) + len(missing_label) >= 2
+
+ if len(init_pred) > 0:
+ init_rmse = np.sqrt((((init_pred - init_label) / (maxv - minv)) ** 2).mean())
+ irmse.append(init_rmse)
+
+ if len(missing_pred) > 0:
+ missing_rmse = np.sqrt((((missing_pred - missing_label)/ (maxv - minv)) ** 2).mean())
+ mrmse.append(missing_rmse)
+
+ metric_list.append(np.mean(mrmse))
+ metric_list.append(np.mean(irmse))
+
+ metric_list = np.array(metric_list)
+
+
+ metric_list[0] = np.mean(metric_list[2:][::2])
+ metric_list[1] = np.mean(metric_list[3:][::2])
+
+ return metric_list
+
+
+def save_model(p_dict, name='best.ckpt', folder=None):
+ args = p_dict['args']
+ if folder is None:
+ folder = os.path.join(args.data_dir, args.dataset, 'models')
+ # name = '{:s}-snm-{:d}-snr-{:d}-value-{:d}-trend-{:d}-cat-{:d}-lt-{:d}-size-{:d}-seed-{:d}-loss-{:s}-{:d}-{:s}'.format(args.task, args.split_num, args.split_nor, args.use_value, args.use_trend, args.use_cat, args.last_time, args.embed_size, args.seed, args.loss, args.time, name)
+ # name = '{:s}-{:s}-{:d}-variables-{:d}{:d}{:d}-{:s}'.format(args.dataset, args.model, len(args.name_list), args.use_ta, args.use_ve, args.use_mm, name)
+ if not os.path.exists(folder):
+ os.mkdir(folder)
+ model = p_dict['model']
+ state_dict = model.state_dict()
+ for key in state_dict.keys():
+ state_dict[key] = state_dict[key].cpu()
+ all_dict = {
+ 'epoch': p_dict['epoch'],
+ 'args': p_dict['args'],
+ 'best_metric': p_dict['best_metric'],
+ 'state_dict': state_dict
+ }
+ torch.save(all_dict, os.path.join(folder, name))
+
+def load_model(p_dict, model_file):
+ all_dict = torch.load(model_file)
+ p_dict['epoch'] = all_dict['epoch']
+ # p_dict['args'] = all_dict['args']
+ p_dict['best_metric'] = all_dict['best_metric']
+ # for k,v in all_dict['state_dict'].items():
+ # p_dict['model_dict'][k].load_state_dict(all_dict['state_dict'][k])
+ p_dict['model'].load_state_dict(all_dict['state_dict'])
+
+def compute_auc(labels, preds):
+ fpr, tpr, thr = metrics.roc_curve(labels, preds)
+ return metrics.auc(fpr, tpr)
diff --git a/code/imputation/main.py b/code/imputation/main.py
new file mode 100644
index 0000000..4720546
--- /dev/null
+++ b/code/imputation/main.py
@@ -0,0 +1,264 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+
+import sys
+
+import os
+import sys
+import time
+import numpy as np
+from sklearn import metrics
+import random
+import json
+from glob import glob
+from collections import OrderedDict
+from tqdm import tqdm
+
+
+import torch
+from torch.autograd import Variable
+from torch.backends import cudnn
+from torch.nn import DataParallel
+from torch.utils.data import DataLoader
+
+import data_loader
+from models import tame
+import myloss
+import function
+
+sys.path.append('../tools')
+import parse, py_op
+
+args = parse.args
+args.hidden_size = args.rnn_size = args.embed_size
+if torch.cuda.is_available():
+ args.gpu = 1
+else:
+ args.gpu = 0
+if args.model != 'tame':
+ args.use_ve = 0
+ args.use_mm = 0
+ args.use_ta = 0
+if args.use_ve == 0:
+ args.value_embedding = 'no'
+print('epochs,', args.epochs)
+
+def _cuda(tensor, is_tensor=True):
+ if args.gpu:
+ if is_tensor:
+            return tensor.cuda(non_blocking=True)
+ else:
+ return tensor.cuda()
+ else:
+ return tensor
+
+def get_lr(epoch):
+    # Constant learning rate; the staged decay schedule below is never reached.
+    lr = args.lr
+    return lr
+
+ if epoch <= args.epochs * 0.5:
+ lr = args.lr
+ elif epoch <= args.epochs * 0.75:
+ lr = 0.1 * args.lr
+ elif epoch <= args.epochs * 0.9:
+ lr = 0.01 * args.lr
+ else:
+ lr = 0.001 * args.lr
+ return lr
+
+def index_value(data):
+ '''
+ map data to index and value
+ '''
+ if args.use_ve == 0:
+ data = Variable(_cuda(data)) # [bs, 250]
+ return data
+ data = data.numpy()
+    index = data // (args.split_num + 1)
+ value = data % (args.split_num + 1)
+ index = Variable(_cuda(torch.from_numpy(index.astype(np.int64))))
+ value = Variable(_cuda(torch.from_numpy(value.astype(np.int64))))
+ return [index, value]
+
+def train_eval(data_loader, net, loss, epoch, optimizer, best_metric, phase='train'):
+ print(phase)
+ lr = get_lr(epoch)
+ if phase == 'train':
+ net.train()
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+ else:
+ net.eval()
+
+ loss_list, pred_list, label_list, mask_list = [], [], [], []
+ feature_mm_dict = py_op.myreadjson(os.path.join(args.file_dir, args.dataset + '_feature_mm_dict.json'))
+ for b, data_list in enumerate(tqdm(data_loader)):
+ data, label, mask, files = data_list[:4]
+ data = index_value(data)
+ if args.model == 'tame':
+ pre_input, pre_time, post_input, post_time, dd_list = data_list [4:9]
+ pre_input = index_value(pre_input)
+ post_input = index_value(post_input)
+ pre_time = Variable(_cuda(pre_time))
+ post_time = Variable(_cuda(post_time))
+ dd_list = Variable(_cuda(dd_list))
+ neib = [pre_input, pre_time, post_input, post_time]
+
+ label = Variable(_cuda(label)) # [bs, 1]
+ mask = Variable(_cuda(mask)) # [bs, 1]
+ if args.dataset in ['MIMIC'] and args.model == 'tame' and args.use_mm:
+ output = net(data, neib=neib, dd=dd_list, mask=mask) # [bs, 1]
+ elif args.model == 'tame' and args.use_ta:
+ output = net(data, neib=neib, mask=mask) # [bs, 1]
+ else:
+ output = net(data, mask=mask) # [bs, 1]
+
+ if phase == 'test':
+ folder = os.path.join(args.result_dir, args.dataset, 'imputation_result')
+ output_data = output.data.cpu().numpy()
+ mask_data = mask.data.cpu().numpy().max(2)
+ for (icu_data, icu_mask, icu_file) in zip(output_data, mask_data, files):
+ icu_file = os.path.join(folder, icu_file.split('/')[-1].replace('.csv', '.npy'))
+ np.save(icu_file, icu_data)
+ if args.dataset == 'MIMIC':
+ with open(os.path.join(args.data_dir, args.dataset, 'train_groundtruth', icu_file.split('/')[-1].replace('.npy', '.csv'))) as f:
+ init_data = f.read().strip().split('\n')
+ # print(icu_file)
+ wf = open(icu_file.replace('.npy', '.csv'), 'w')
+ wf.write(init_data[0] + '\n')
+ item_list = init_data[0].strip().split(',')
+ if len(init_data) <= args.n_visit:
+ try:
+ assert len(init_data) == (icu_mask >= 0).sum() + 1
+ except:
+ pass
+ # print(len(init_data))
+ # print(sum(icu_mask >= 0))
+ # print(icu_file)
+ for init_line, out_line in zip(init_data[1:], icu_data):
+ init_line = init_line.strip().split(',')
+ new_line = [init_line[0]]
+ # assert len(init_line) == len(out_line) + 1
+ for item, iv, ov in zip(item_list[1:], init_line[1:], out_line):
+ if iv.strip() not in ['', 'NA']:
+ new_line.append('{:4.4f}'.format(float(iv.strip())))
+ else:
+ minv, maxv = feature_mm_dict[item]
+ ov = ov * (maxv - minv) + minv
+ new_line.append('{:4.4f}'.format(ov))
+ new_line = ','.join(new_line)
+ wf.write(new_line + '\n')
+ wf.close()
+
+
+ loss_output = loss(output, label, mask)
+ pred_list.append(output.data.cpu().numpy())
+ loss_list.append(loss_output.data.cpu().numpy())
+ label_list.append(label.data.cpu().numpy())
+ mask_list.append(mask.data.cpu().numpy())
+
+ if phase == 'train':
+ optimizer.zero_grad()
+ loss_output.backward()
+ optimizer.step()
+
+
+ pred = np.concatenate(pred_list, 0)
+ label = np.concatenate(label_list, 0)
+ mask = np.concatenate(mask_list, 0)
+ metric_list = function.compute_nRMSE(pred, label, mask)
+ avg_loss = np.mean(loss_list)
+
+ print('\nTrain Epoch %03d (lr %.5f)' % (epoch, lr))
+ print('loss: {:3.4f} \t'.format(avg_loss))
+ print('metric: {:s}'.format('\t'.join(['{:3.4f}'.format(m) for m in metric_list[:2]])))
+
+
+ metric = metric_list[0]
+ if phase == 'valid' and (best_metric[0] == 0 or best_metric[0] > metric):
+ best_metric = [metric, epoch]
+ function.save_model({'args': args, 'model': net, 'epoch':epoch, 'best_metric': best_metric})
+ metric_list = metric_list[2:]
+ name_list = args.name_list
+ assert len(metric_list) == len(name_list) * 2
+ s = args.model
+    for i in range(len(metric_list) // 2):
+ name = name_list[i] + ''.join(['.' for _ in range(40 - len(name_list[i]))])
+ print('{:s}{:3.4f}......{:3.4f}'.format(name, metric_list[2*i], metric_list[2*i+1]))
+ s = s+ ' {:3.4f}'.format(metric_list[2*i])
+ if phase != 'train':
+ print('\t\t\t\t best epoch: {:d} best MSE on missing value: {:3.4f} \t'.format(best_metric[1], best_metric[0]))
+ print(s)
+ return best_metric
+
+
+def main():
+
+ assert args.dataset in ['DACMI', 'MIMIC']
+ if args.dataset == 'MIMIC':
+ args.n_ehr = len(py_op.myreadjson(os.path.join(args.data_dir, args.dataset, 'ehr_list.json')))
+ args.name_list = py_op.myreadjson(os.path.join(args.file_dir, args.dataset+'_feature_list.json'))[1:]
+ args.output_size = len(args.name_list)
+ files = sorted(glob(os.path.join(args.data_dir, args.dataset, 'train_with_missing/*.csv')))
+ data_splits = py_op.myreadjson(os.path.join(args.file_dir, args.dataset + '_splits.json'))
+ train_files = [f for idx in [0, 1, 2, 3, 4, 5, 6] for f in data_splits[idx]]
+ valid_files = [f for idx in [7] for f in data_splits[idx]]
+ test_files = [f for idx in [8, 9] for f in data_splits[idx]]
+ if args.phase == 'test':
+ train_phase, valid_phase, test_phase, train_shuffle = 'test', 'test', 'test', False
+ else:
+ train_phase, valid_phase, test_phase, train_shuffle = 'train', 'valid', 'test', True
+ train_dataset = data_loader.DataBowl(args, train_files, phase=train_phase)
+ valid_dataset = data_loader.DataBowl(args, valid_files, phase=valid_phase)
+ test_dataset = data_loader.DataBowl(args, test_files, phase=test_phase)
+ train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=train_shuffle, num_workers=args.workers, pin_memory=True)
+ valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
+ test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
+ args.vocab_size = (args.output_size + 2) * (1 + args.split_num) + 5
+
+ if args.model == 'tame':
+ net = tame.AutoEncoder(args)
+ mloss = myloss.MSELoss(args)
+ gloss = myloss.GaussianLoss(args)
+
+ net = _cuda(net, 0)
+ mloss = _cuda(mloss, 0)
+ gloss = _cuda(gloss, 0)
+    loss = mloss   # train_eval calls loss(output, label, mask); GaussianLoss expects (pred, std, label, mask)
+
+ best_metric= [0,0]
+ start_epoch = 0
+
+ if args.resume:
+ p_dict = {'model': net}
+ function.load_model(p_dict, args.resume)
+ best_metric = p_dict['best_metric']
+ start_epoch = p_dict['epoch'] + 1
+
+ parameters_all = []
+ for p in net.parameters():
+ parameters_all.append(p)
+
+ optimizer = torch.optim.Adam(parameters_all, args.lr)
+
+ if args.phase == 'train':
+ for epoch in range(start_epoch, args.epochs):
+ print('start epoch :', epoch)
+ train_eval(train_loader, net, loss, epoch, optimizer, best_metric)
+ best_metric = train_eval(valid_loader, net, loss, epoch, optimizer, best_metric, phase='valid')
+        print('best metric', best_metric)
+
+ elif args.phase == 'test':
+ folder = os.path.join(args.result_dir, args.dataset, 'imputation_result')
+ os.system('rm -r ' + folder)
+ os.system('mkdir ' + folder)
+
+ train_eval(train_loader, net, loss, 0, optimizer, best_metric, 'test')
+ train_eval(valid_loader, net, loss, 0, optimizer, best_metric, 'test')
+ train_eval(test_loader, net, loss, 0, optimizer, best_metric, 'test')
+
+if __name__ == '__main__':
+ main()
diff --git a/code/imputation/models/__init__.py b/code/imputation/models/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/code/imputation/models/__init__.py
@@ -0,0 +1 @@
+
diff --git a/code/imputation/models/tame.py b/code/imputation/models/tame.py
new file mode 100644
index 0000000..cd52887
--- /dev/null
+++ b/code/imputation/models/tame.py
@@ -0,0 +1,247 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import json
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.autograd import *
+
+import numpy as np
+
+import sys
+sys.path.append('../tools')
+import parse, py_op
+
+def value_embedding_data(d = 512, split = 200):
+    vec = np.array([np.arange(split) * i for i in range(d // 2)], dtype=np.float32).transpose()
+ vec = vec / vec.max()
+ embedding = np.concatenate((np.sin(vec), np.cos(vec)), 1)
+ embedding[0, :d] = 0
+ embedding = torch.from_numpy(embedding)
+ return embedding
+
+class AutoEncoder(nn.Module):
+ def __init__(self, args):
+ super(AutoEncoder, self).__init__()
+ self.args = args
+
+ if args.value_embedding == 'no':
+ self.embedding = nn.Linear(args.output_size, args.embed_size)
+ else:
+ self.embedding = nn.Embedding (args.vocab_size, args.embed_size )
+ self.lstm = nn.LSTM ( input_size=args.embed_size,
+ hidden_size=args.hidden_size,
+ num_layers=args.num_layers,
+ batch_first=True,
+ bidirectional=args.brnn)
+ if args.dataset in ['MIMIC']:
+ self.dd_embedding = nn.Embedding (args.n_ehr, args.embed_size )
+ self.value_embedding = nn.Embedding.from_pretrained(value_embedding_data(args.embed_size, args.split_num + 1))
+ self.value_mapping = nn.Sequential(
+ nn.Linear ( args.embed_size * 2, args.embed_size),
+ nn.ReLU ( ),
+ nn.Dropout ( 0.1),
+ )
+ self.pre_embedding = nn.Sequential(
+ nn.Linear ( args.embed_size * 2, args.embed_size),
+ nn.ReLU ( ),
+ nn.Linear ( args.embed_size, args.embed_size),
+ nn.ReLU ( ),
+ )
+ self.post_embedding = nn.Sequential(
+ nn.Linear ( args.embed_size * 2, args.embed_size),
+ nn.ReLU ( ),
+ nn.Linear ( args.embed_size, args.embed_size),
+ nn.ReLU ( ),
+ )
+ self.pre_mapping = nn.Sequential(
+ nn.Linear ( args.embed_size, args.embed_size),
+ nn.ReLU ( ),
+ nn.Linear ( args.embed_size, args.embed_size),
+ nn.ReLU ( ),
+ )
+ self.post_mapping = nn.Sequential(
+ nn.Linear ( args.embed_size, args.embed_size),
+ nn.ReLU ( ),
+ nn.Linear ( args.embed_size, args.embed_size),
+ nn.ReLU ( ),
+ )
+ self.dd_mapping = nn.Sequential(
+ nn.Linear ( args.embed_size, args.embed_size),
+ nn.ReLU ( ),
+ nn.Dropout(0.1),
+ nn.Linear ( args.embed_size, args.embed_size),
+ nn.ReLU ( ),
+ nn.Dropout(0.1),
+ )
+ self.mapping = nn.Sequential(
+ nn.Linear ( args.embed_size, args.embed_size),
+ nn.ReLU ( ),
+ nn.Linear ( args.embed_size, args.embed_size),
+ nn.ReLU ( ),
+ )
+ # self.neib_embedding = nn.Sequential(
+ # nn.Linear ( args.embed_size * 2, args.embed_size),
+ # nn.ReLU ( ),
+ # nn.Dropout ( 0.1),
+ # )
+
+ self.embed_linear = nn.Sequential (
+ nn.Linear ( args.embed_size, args.embed_size),
+ nn.ReLU ( ),
+ # nn.Dropout ( 0.25 ),
+ # nn.Linear ( args.embed_size, args.embed_size),
+ # nn.ReLU ( ),
+ nn.Dropout ( 0.1),
+ )
+ self.relu = nn.ReLU ( )
+
+ lstm_size = args.rnn_size
+ if args.brnn:
+ lstm_size *= 2
+ hidden_size = args.hidden_size
+ self.tah_mapping = nn.Sequential (
+ nn.Linear(lstm_size, args.hidden_size),
+ nn.ReLU ( ),
+ nn.Dropout ( 0.1),
+ nn.Linear ( args.hidden_size, hidden_size),
+ )
+ self.tav_mapping = nn.Sequential (
+ nn.Linear(args.hidden_size, args.hidden_size),
+ nn.ReLU ( ),
+ nn.Dropout ( 0.1),
+ nn.Linear ( args.hidden_size, hidden_size),
+ )
+ self.output = nn.Sequential (
+ nn.Linear (lstm_size, args.rnn_size),
+ nn.ReLU ( ),
+ nn.Dropout ( 0.1),
+ nn.Linear ( args.rnn_size, hidden_size),
+ )
+ self.value = nn.Sequential (
+ nn.Linear (hidden_size, hidden_size),
+ nn.ReLU ( ),
+ nn.Dropout ( 0.1),
+ nn.Linear (hidden_size, args.output_size * 2),
+ )
+ self.pooling = nn.AdaptiveMaxPool1d(1)
+
+ def visit_pooling(self, x):
+ output = x
+ size = output.size()
+ output = output.view(size[0] * size[1], size[2], output.size(3)) # (64*30, 13, 512)
+ output = torch.transpose(output, 1,2).contiguous() # (64*30, 512, 13)
+ output = self.pooling(output) # (64*30, 512, 1)
+ output = output.view(size[0], size[1], size[3]) # (64, 30, 512)
+ return output
+
+ def value_order_embedding(self, x):
+ size = list(x[0].size()) # (64, 30, 13)
+ index, value = x
+ xi = self.embedding(index.view(-1)) # (64*30*13, 512)
+ # xi = xi * (value.view(-1).float() + 1.0 / self.args.split_num)
+ xv = self.value_embedding(value.view(-1)) # (64*30*13, 512)
+ x = torch.cat((xi, xv), 1) # (64*30*13, 1024)
+ x = self.value_mapping(x) # (64*30*13, 512)
+ size.append(-1)
+ x = x.view(size) # (64, 30, 13, 512)
+ return x
+
+ def pp_value_embedding(self, neib):
+ size = list(neib[1].size())
+ # print(type(neib[0]))
+ # print(len(neib[0]))
+ if self.args.use_ve == 0:
+ pre_x = self.embedding(neib[0])
+ post_x = self.embedding(neib[2])
+ pre_x = self.pre_mapping(pre_x)
+ post_x = self.post_mapping(post_x)
+ else:
+ pre_x = self.value_order_embedding(neib[0])
+ post_x = self.value_order_embedding(neib[2])
+
+ pre_t = self.value_embedding(neib[1].view(-1))
+ post_t = self.value_embedding(neib[3].view(-1))
+ size.append(-1)
+ pre_t = pre_t.view(size)
+ post_t = post_t.view(size)
+
+ pre_x = self.pre_embedding(torch.cat((pre_x, pre_t), 3))
+ post_x = self.post_embedding(torch.cat((post_x, post_t), 3))
+ return pre_x, post_x
+
+ def time_aware_attention(self, hidden, vectors):
+ # hidden [64, 30, 1024]
+ # vectors [64, 30, 54, 512]
+ wh = self.tah_mapping(hidden)
+ wh = wh.unsqueeze(2)
+ wh = wh.expand_as(vectors)
+ wv = self.tav_mapping(vectors)
+ beta = wh + wv # [64, 30, 54, 512]
+ alpha = F.softmax(beta, 2)
+
+ alpha = alpha.transpose(2,3).contiguous()
+ vectors = vectors.transpose(2,3).contiguous()
+ vsize = list(vectors.size()) # [64, 30, 512, 54]
+ # print(alpha.size())
+ # print(vectors.size())
+ alpha = alpha.view((-1, 1, alpha.size(3)))
+ vectors = vectors.view((-1, vectors.size(3), 1))
+ # print(alpha.size())
+ # print(vectors.size())
+ att_res = torch.bmm(alpha, vectors)
+ # print(att_res.size())
+ att_res = att_res.view(vsize[:3])
+ # print(att_res.size())
+ # return att_res
+ return att_res
+
+
+
+ def forward(self, x, neib=None, dd=None, mask=None):
+ # print()
+
+ # embedding
+ if self.args.value_embedding == 'no':
+ x = self.embedding( x ) # (64, 30, 512)
+ else:
+ if self.args.value_embedding == 'use_order':
+ x = self.value_order_embedding(x)
+ # x = self.neib_embedding(torch.cat((x, pre_x, post_x), 3))
+ else:
+ size = list(x.size()) # (64, 30, 13)
+ x = x.view(-1)
+ x = self.embedding( x ) # (64*30*13, 512)
+
+
+ # print(x.size())
+
+ if dd is not None:
+ dsize = list(dd.size()) + [-1]
+ d = self.dd_embedding(dd.view(-1)).view(dsize)
+ d = self.dd_mapping(d)
+ x = torch.cat((x, d), 2)
+ x = self.mapping(x)
+
+ x = self.visit_pooling(x) # (64, 30, 512)
+
+ # lstm
+ lstm_out, _ = self.lstm( x ) # (64, 30, 1024)
+ out = self.output(lstm_out)
+
+ if neib is not None and self.args.use_ta:
+ pre_x, post_x = self.pp_value_embedding(neib)
+ pp = torch.cat((pre_x, post_x), 2)
+ out = out + self.time_aware_attention(lstm_out, pp)
+        elif neib is not None:
+ pre_x, post_x = self.pp_value_embedding(neib)
+ pp = torch.cat((pre_x, post_x), 2)
+ pp = self.visit_pooling(pp)
+ out = out + pp
+
+ value = self.value(out)
+
+ return value
+
diff --git a/code/imputation/myloss.py b/code/imputation/myloss.py
new file mode 100644
index 0000000..e7a8219
--- /dev/null
+++ b/code/imputation/myloss.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import json
+import sys
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.autograd import *
+import numpy as np
+
+# args is used by Loss.forward (hard mining); import it the same way as main.py does
+sys.path.append('../tools')
+import parse
+args = parse.args
+
+def hard_mining(neg_output, neg_labels, num_hard, largest=True):
+ num_hard = min(max(num_hard, 10), len(neg_output))
+ _, idcs = torch.topk(neg_output, min(num_hard, len(neg_output)), largest=largest)
+ neg_output = torch.index_select(neg_output, 0, idcs)
+ neg_labels = torch.index_select(neg_labels, 0, idcs)
+ return neg_output, neg_labels
+
+
+class Loss(nn.Module):
+ def __init__(self):
+ super(Loss, self).__init__()
+ self.classify_loss = nn.BCELoss()
+
+ def forward(self, prob, labels, train=True):
+
+ pos_ind = labels > 0.5
+ neg_ind = labels < 0.5
+ pos_label = labels[pos_ind]
+ neg_label = labels[neg_ind]
+ pos_prob = prob[pos_ind]
+ neg_prob = prob[neg_ind]
+ pos_loss, neg_loss = 0, 0
+
+ # hard mining
+ num_hard_pos = 2
+ num_hard_neg = 6
+ if args.hard_mining:
+ pos_prob, pos_label= hard_mining(pos_prob, pos_label, num_hard_pos, largest=False)
+ neg_prob, neg_label= hard_mining(neg_prob, neg_label, num_hard_neg, largest=True)
+
+ if len(pos_prob):
+ pos_loss = 0.5 * self.classify_loss(pos_prob, pos_label)
+
+ if len(neg_prob):
+ neg_loss = 0.5 * self.classify_loss(neg_prob, neg_label)
+ classify_loss = pos_loss + neg_loss
+ # classify_loss = self.classify_loss(prob, labels)
+
+ # stati number
+ prob = prob.data.cpu().numpy() > 0.5
+ labels = labels.data.cpu().numpy()
+ pos_l = (labels==1).sum()
+ neg_l = (labels==0).sum()
+ pos_p = (prob + labels == 2).sum()
+ neg_p = (prob + labels == 0).sum()
+
+ return [classify_loss, pos_p, pos_l, neg_p, neg_l]
+
+
+class MSELoss(nn.Module):
+ def __init__(self, args):
+ super(MSELoss, self).__init__()
+ self.args = args
+ assert self.args.loss in ['missing', 'init', 'both']
+ self.mseloss = nn.MSELoss()
+
+ def forward(self, pred, label, mask):
+ pred = pred.view(-1)
+ label = label.view(-1)
+ mask = mask.view(-1)
+ assert len(pred) == len(label) == len(mask)
+
+ indices = mask==1
+ ipred = pred[indices]
+ ilabel = label[indices]
+ loss = self.mseloss(ipred, ilabel)
+
+ if self.args.loss == 'both':
+ indices = mask==0
+ ipred = pred[indices]
+ ilabel = label[indices]
+ loss += self.mseloss(ipred, ilabel) # * 0.1
+
+ # print('pred.shape', pred.size())
+ return loss
+
+
+class GaussianLoss(nn.Module):
+ def __init__(self, args):
+        super(GaussianLoss, self).__init__()
+ self.args = args
+ assert self.args.loss in ['missing', 'init', 'both']
+ self.mseloss = nn.MSELoss()
+
+ def forward(self, pred, std, label, mask):
+ pred = pred.view(-1)
+ label = label.view(-1)
+ std = std.view(-1)
+ mask = mask.view(-1)
+ assert len(pred) == len(label) == len(mask)
+
+ indices = mask==1
+ ipred = pred[indices]
+ ilabel = label[indices]
+ istd = std[indices]
+ loss = (ipred - ilabel) * (ipred - ilabel) / istd + istd / 2
+ return torch.sum(loss)/len(ilabel)
diff --git a/code/prediction/active_sensing.py b/code/prediction/active_sensing.py
new file mode 100644
index 0000000..75a4e53
--- /dev/null
+++ b/code/prediction/active_sensing.py
@@ -0,0 +1,192 @@
+# coding=utf8
+
+
+'''
+Entry point of the active sensing script.
+'''
+
+
+# Standard dependencies
+import os
+import sys
+import time
+import json
+import traceback
+import numpy as np
+from glob import glob
+from tqdm import tqdm
+from tools import parse, py_op
+
+
+# torch
+import torch
+import torchvision
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+
+
+# Project modules
+import loss
+import models
+import function
+import loaddata
+# import framework
+from loaddata import dataloader
+from models import *
+from myloss import *
+
+
+# Global arguments
+args = parse.args
+args.hard_mining = 0
+args.gpu = 1
+args.use_trend = max(args.use_trend, args.use_value)
+args.use_value = max(args.use_trend, args.use_value)
+args.rnn_size = args.embed_size
+args.hidden_size = args.embed_size
+
+
+
+
+def train_eval(p_dict, phase='train'):
+ def active_sensing(data, model):
+ if args.gpu:
+ data = [ Variable(x.cuda()) for x in data ]
+ visits, values, stds, master, labels, times, observed_value = data
+ output = model(visits, master, times, phase, values)
+ gradient = compute_gradient(output)
+ uncertainty = compute_uncertainty(gradient, stds)
+ u_value = uncertainty.cpu().data.numpy()
+
+ if max(u_value) > args.uncertainty_threshold:
+ values = update_value(values, u_value, observed_value)
+ output = model(visits, master, times, phase, values)
+ function.compute_metric(output, labels, time, classification_loss_output, classification_metric_dict, phase)
+
+    ### unpack parameters
+ epoch = p_dict['epoch']
+    model = p_dict['model']                    # model
+    loss = p_dict['loss']                      # loss function
+    if phase == 'train':
+        data_loader = p_dict['train_loader']   # training data
+        optimizer = p_dict['optimizer']        # optimizer
+ else:
+ data_loader = p_dict['val_loader']
+
+    ### local variables
+ classification_metric_dict = dict()
+ # if args.task == 'case1':
+
+ for i,data in enumerate(tqdm(data_loader)):
+ active_sensing(data, model)
+
+ print('\nEpoch: {:d} \t Phase: {:s} \n'.format(epoch, phase))
+ metric = function.print_metric('classification', classification_metric_dict, phase)
+ if phase == 'val':
+ if metric > p_dict['best_metric'][0]:
+ p_dict['best_metric'] = [metric, epoch]
+ function.save_model(p_dict)
+
+ print('valid: metric: {:3.4f}\t epoch: {:d}\n'.format(metric, epoch))
+ print('\t\t\t valid: best_metric: {:3.4f}\t epoch: {:d}\n'.format(p_dict['best_metric'][0], p_dict['best_metric'][1]))
+ else:
+ print('train: metric: {:3.4f}\t epoch: {:d}\n'.format(metric, epoch))
+
+
+
+def main():
+ p_dict = dict() # All the parameters
+ p_dict['args'] = args
+ args.split_nn = args.split_num + args.split_nor * 3
+ args.vocab_size = args.split_nn * 145 + 1
+    print('vocab_size', args.vocab_size)
+
+ ### load data
+    print('read data ...')
+ patient_time_record_dict = py_op.myreadjson(os.path.join(args.result_dir, 'patient_time_record_dict.json'))
+ patient_master_dict = py_op.myreadjson(os.path.join(args.result_dir, 'patient_master_dict.json'))
+ patient_label_dict = py_op.myreadjson(os.path.join(args.result_dir, 'patient_label_dict.json'))
+
+ patient_train = list(json.load(open(os.path.join(args.file_dir, args.task, 'train.json'))))
+ patient_valid = list(json.load(open(os.path.join(args.file_dir, args.task, 'val.json'))))
+
+ if len(patient_train) > len(patient_label_dict):
+ patients = patient_time_record_dict.keys()
+ patients = patient_label_dict.keys()
+ n = int(0.8 * len(patients))
+ patient_train = patients[:n]
+ patient_valid = patients[n:]
+
+
+
+
+
+    print('data loading ...')
+ train_dataset = dataloader.DataSet(
+ patient_train,
+ patient_time_record_dict,
+ patient_label_dict,
+ patient_master_dict,
+ args=args,
+ phase='train')
+ train_loader = DataLoader(
+ dataset=train_dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ num_workers=8,
+ pin_memory=True)
+ val_dataset = dataloader.DataSet(
+ patient_valid,
+ patient_time_record_dict,
+ patient_label_dict,
+ patient_master_dict,
+ args=args,
+ phase='val')
+ val_loader = DataLoader(
+ dataset=val_dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=8,
+ pin_memory=True)
+
+ p_dict['train_loader'] = train_loader
+ p_dict['val_loader'] = val_loader
+
+
+
+ cudnn.benchmark = True
+ net = lstm.LSTM(args)
+ if args.gpu:
+ net = net.cuda()
+ p_dict['loss'] = loss.Loss().cuda()
+ else:
+ p_dict['loss'] = loss.Loss()
+
+ parameters = []
+ for p in net.parameters():
+ parameters.append(p)
+ optimizer = torch.optim.Adam(parameters, lr=args.lr)
+ p_dict['optimizer'] = optimizer
+ p_dict['model'] = net
+ start_epoch = 0
+ # args.epoch = start_epoch
+ # print ('best_f1score' + str(best_f1score))
+
+ p_dict['epoch'] = 0
+ p_dict['best_metric'] = [0, 0]
+
+
+ ### resume pretrained model
+ if os.path.exists(args.resume):
+        print('resume from model ' + args.resume)
+        function.load_model(p_dict, args.resume)
+        print('best_metric', p_dict['best_metric'])
+ train_eval(p_dict, 'val')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/code/prediction/compare.py b/code/prediction/compare.py
new file mode 100644
index 0000000..fb72215
--- /dev/null
+++ b/code/prediction/compare.py
@@ -0,0 +1,6 @@
+import os
+
+for seed in [1, 40, 500, 100, 2019]:
+ for task in ['case1', 'task1', 'task2']:
+ cmd = 'python main.py --task {:s} --seed {:d}'.format(task, seed)
+ os.system(cmd)
diff --git a/code/prediction/function.py b/code/prediction/function.py
new file mode 100644
index 0000000..a03cd92
--- /dev/null
+++ b/code/prediction/function.py
@@ -0,0 +1,241 @@
+# coding=utf8
+#########################################################################
+# File Name: function.py
+# Author: ccyin
+# mail: ccyin04@gmail.com
+# Created Time: Wednesday, June 12, 2019, 14:28:43
+#########################################################################
+
+import os
+
+from sklearn import metrics
+import numpy as np
+
+import torch
+
+# file
+import loaddata
+from tools import parse
+from loaddata import data_function
+
+args = parse.args
+
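+# Checkpoint helpers: the saved filename encodes the task and the key hyper-parameters
+# (split sizes, value/trend/cat flags, last time, embedding size, seed).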
+def save_model(p_dict, name='best.ckpt', folder='../data/models/'):
+ args = p_dict['args']
+ name = '{:s}-snm-{:d}-snr-{:d}-value-{:d}-trend-{:d}-cat-{:d}-lt-{:d}-size-{:d}-seed-{:d}-{:s}'.format(args.task,
+ args.split_num, args.split_nor, args.use_value, args.use_trend,
+ args.use_cat, args.last_time, args.embed_size, args.seed, name)
+ if not os.path.exists(folder):
+ os.mkdir(folder)
+ model = p_dict['model']
+ state_dict = model.state_dict()
+ for key in state_dict.keys():
+ state_dict[key] = state_dict[key].cpu()
+ all_dict = {
+ 'epoch': p_dict['epoch'],
+ 'args': p_dict['args'],
+ 'best_metric': p_dict['best_metric'],
+ 'state_dict': state_dict
+ }
+ torch.save(all_dict, os.path.join(folder, name))
+
+def load_model(p_dict, model_file):
+ all_dict = torch.load(model_file)
+ p_dict['epoch'] = all_dict['epoch']
+ # p_dict['args'] = all_dict['args']
+ p_dict['best_metric'] = all_dict['best_metric']
+ # for k,v in all_dict['state_dict'].items():
+ # p_dict['model_dict'][k].load_state_dict(all_dict['state_dict'][k])
+ p_dict['model'].load_state_dict(all_dict['state_dict'])
+
+
+def save_segmentation_results(images, segmentations, folder='../data/middle_segmentation'):
+ stride = args.stride
+
+ if not os.path.exists(folder):
+ os.mkdir(folder)
+
+ # images = images.data.cpu().numpy()
+ # segmentations = segmentations.data.cpu().numpy()
+ images = (images * 128) + 127
+ segmentations[segmentations>0] = 255
+ segmentations[segmentations<0] = 0
+
+ # print(images.shape, segmentations.shape)
+ for ii, image, seg in zip(range(len(images)), images, segmentations):
+ image = data_function.numpy_to_image(image)
+ new_seg = np.zeros([3, seg.shape[1] * stride, seg.shape[2] * stride])
+ for i in range(seg.shape[1]):
+ for j in range(seg.shape[2]):
+ for k in range(3):
+ new_seg[k, i*stride:(i+1)*stride, j*stride:(j+1)*stride] = seg[0,i,j]
+ seg = new_seg
+ seg = data_function.numpy_to_image(seg)
+ image.save(os.path.join(folder, str(ii) + '_image.png'))
+ seg.save(os.path.join(folder, str(ii) + '_seg.png'))
+
+
+def save_middle_results(data, folder = '../data/middle_images'):
+ stride = args.stride
+
+ if not os.path.exists(folder):
+ os.mkdir(folder)
+ numpy_data = [x.data.numpy() for x in data[1:]]
+ data = data[:1] + numpy_data
+ image_names, images, word_labels, seg_labels, bbox_labels, bbox_images = data[:6]
+ images = (images * 128) + 127
+ seg_labels = seg_labels*127 + 127
+
+
+ for ii, name, image, seg, bbox_image in zip(range(len(image_names)), image_names, images, seg_labels, bbox_images):
+ name = name.split('/')[-1]
+ image = data_function.numpy_to_image(image)
+ new_seg = np.zeros([3, seg.shape[1] * stride, seg.shape[2] * stride])
+ # print(seg[0].max(),seg[0].min())
+ for i in range(seg.shape[1]):
+ for j in range(seg.shape[2]):
+ for k in range(3):
+ new_seg[k, i*stride:(i+1)*stride, j*stride:(j+1)*stride] = seg[0,i,j]
+ seg = new_seg
+ seg = data_function.numpy_to_image(seg)
+ # image.save(os.path.join(folder, name))
+ # seg.save(os.path.join(folder, name.replace('image.png', 'seg.png')))
+ image.save(os.path.join(folder, str(ii) + '_image.png'))
+ seg.save(os.path.join(folder, str(ii) + '_seg.png'))
+
+ for ib,bimg in enumerate(bbox_image):
+ # print(bimg.max(), bimg.min(), bimg.dtype)
+ bimg = data_function.numpy_to_image(bimg)
+ bimg.save(os.path.join(folder, str(ii)+'_'+ str(ib) + '_bbox.png'))
+
+def save_detection_results(names, images, detect_character_output, folder='../data/test_results/'):
+ stride = args.stride
+
+ if not os.path.exists(folder):
+ os.mkdir(folder)
+ # images = images.data.cpu().numpy() # [bs, 3, w, h]
+ images = (images * 128) + 127
+ # detect_character_output = detect_character_output.data.cpu().numpy() # [bs, w, h, n_anchors, 5+class]
+
+ for i, name, image, bboxes in zip(range(len(names)), names, images, detect_character_output):
+ name = name.split('/')[-1]
+
+ ### save the original image
+ # data_function.numpy_to_image(image).save(os.path.join(folder, name))
+ data_function.numpy_to_image(image).save(os.path.join(folder, str(i) + '_image.png'))
+
+ detected_bbox = detect_function.nms(bboxes)
+ # print([b[-1] for b in detected_bbox])
+ # print(len(detected_bbox))
+ image = data_function.add_bbox_to_image(image, detected_bbox)
+ # image.save(os.path.join(folder, name.replace('.png', '_bbox.png')))
+ image.save(os.path.join(folder, str(i) + '_bbox.png'))
+
+
+
+def compute_detection_metric(outputs, labels, loss_outputs,metric_dict):
+ loss_outputs[0] = loss_outputs[0].data
+ metric_dict['metric'] = metric_dict.get('metric', []) + [loss_outputs]
+
+def compute_segmentation_metric(outputs, labels, loss_outputs, metric_dict):
+ loss_outputs[0] = loss_outputs[0].data
+ metric_dict['metric'] = metric_dict.get('metric', []) + [loss_outputs]
+
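+# Accumulate confusion-matrix counts (tp/tn/fp/fn) and per-batch losses; outside of
+# training, the raw scores and labels are also kept so that AUC can be computed later.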
+def compute_metric(outputs, labels, time, loss_outputs,metric_dict, phase='train'):
+ # loss_output_list, f1score_list, recall_list, precision_list):
+ if phase != 'test':
+ preds = outputs.data.cpu().numpy()
+ labels = labels.data.cpu().numpy()
+ else:
+ preds = np.array(outputs)
+
+ preds = preds.reshape(-1)
+ labels = labels.reshape(-1)
+
+ if time is not None:
+ time = time.reshape(-1)
+ assert preds.shape == time.shape
+ time = time[labels>-0.5]
+ assert preds.shape == labels.shape
+
+ preds = preds[labels>-0.5]
+ label = labels[labels>-0.5]
+
+ pred = preds > 0
+
+ assert len(pred) == len(label)
+
+ tp = (pred + label == 2).sum()
+ tn = (pred + label == 0).sum()
+ fp = (pred - label == 1).sum()
+ fn = (pred - label == -1).sum()
+
+ metric_dict['tp'] = metric_dict.get('tp', 0.0) + tp
+ metric_dict['tn'] = metric_dict.get('tn', 0.0) + tn
+ metric_dict['fp'] = metric_dict.get('fp', 0.0) + fp
+ metric_dict['fn'] = metric_dict.get('fn', 0.0) + fn
+ loss = []
+ for x in loss_outputs:
+ if x == 0:
+ loss.append(x)
+ else:
+ loss.append(x.data.cpu().numpy())
+ # loss = [[x.data.cpu().numpy() for x in loss_outputs]]
+ metric_dict['loss'] = metric_dict.get('loss', []) + [loss]
+ if phase != 'train':
+ metric_dict['preds'] = metric_dict.get('preds', []) + list(preds)
+ metric_dict['labels'] = metric_dict.get('labels', []) + list(label)
+ if time is not None:
+ metric_dict['times'] = metric_dict.get('times', []) + list(time)
+
+def compute_metric_multi_classification(outputs, labels, loss_outputs, metric_dict):
+ preds = outputs.data.cpu().numpy() > 0
+ labels = labels.data.cpu().numpy()
+ for pred, label in zip(preds, labels):
+ pred = np.argmax(pred)
+ tp = (pred == label ).sum()
+ fn = (pred != label).sum()
+ accuracy = 1.0 * tp / (tp + fn)
+ metric_dict['accuracy'] = metric_dict.get('accuracy', []) + [accuracy]
+ metric_dict['loss'] = metric_dict.get('loss', []) + [[x.data.cpu().numpy() for x in loss_outputs]]
+
+
+def print_metric(first_line, metric_dict, phase='train'):
+ print(first_line)
+ loss_array = np.array(metric_dict['loss']).mean(0)
+ tp = metric_dict['tp']
+ tn = metric_dict['tn']
+ fp = metric_dict['fp']
+ fn = metric_dict['fn']
+ accuracy = 1.0 * (tp + tn) / (tp + tn + fp + fn)
+ recall = 1.0 * tp / (tp + fn + 10e-20)
+ precision = 1.0 * tp / (tp + fp + 10e-20)
+ f1score = 2.0 * recall * precision / (recall + precision + 10e-20)
+
+
+
+ loss_array = loss_array.reshape(-1)
+
+ print('loss: {:3.4f}\t pos loss: {:3.4f}\t negloss: {:3.4f}'.format(loss_array[0], loss_array[1], loss_array[2]))
+ print('accuracy: {:3.4f}\t f1score: {:3.4f}\t recall: {:3.4f}\t precision: {:3.4f}'.format(accuracy, f1score, recall, precision))
+ print('\n')
+
+ if phase != 'train':
+ fpr, tpr, thr = metrics.roc_curve(metric_dict['labels'], metric_dict['preds'])
+ return metrics.auc(fpr, tpr)
+ else:
+ return f1score
+
+def load_all():
+ fo = '../data/models'
+ pre = ''
+ for fi in sorted(os.listdir(fo)):
+ if fi[:5] != pre:
+ print()
+ pre = fi[:5]
+ x = torch.load(os.path.join(fo, fi))
+ # print(x['epoch'], fi)
+ print(x['best_metric'], fi)
+
+if __name__ == '__main__':
+ load_all()
+
diff --git a/code/prediction/loaddata/__init__.py b/code/prediction/loaddata/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/code/prediction/loaddata/data_function.py b/code/prediction/loaddata/data_function.py
new file mode 100644
index 0000000..bbbbbcc
--- /dev/null
+++ b/code/prediction/loaddata/data_function.py
@@ -0,0 +1,201 @@
+# coding=utf8
+#########################################################################
+# File Name: data_function.py
+# Author: ccyin
+# mail: ccyin04@gmail.com
+# Created Time: Wednesday, June 12, 2019, 11:28:13
+#########################################################################
+
+import os
+import sys
+import time
+import json
+import numpy as np
+from PIL import Image,ImageDraw,ImageFont,ImageFilter
+
+from tools import parse
+args = parse.args
+
+def add_text_to_img(img, text, size, font, color, place):
+ imgdraw = ImageDraw.Draw(img)
+ imgfont = ImageFont.truetype(font,size=size)
+ imgdraw.text(place, text, fill=color, font=imgfont)
+ return img
+
+def image_to_numpy(image):
+ image = np.array(image)
+ image = image.transpose(2, 0, 1)
+ return image
+
+def numpy_to_image(image):
+ image = image.transpose(1, 2, 0).astype(np.uint8)
+ return Image.fromarray(image)
+
+def add_line(bbox_image, bbox, gray=128, proposal=0):
+
+ # print(bbox, bbox_image.shape)
+
+ sx,sy,ex,ey = bbox[:4]
+ _,x,y = bbox_image.shape # 3, 64, 512
+
+ if not proposal:
+ assert sx <= x
+ assert ex <= x
+ assert sy <= y
+ assert ey <= y
+
+ n = 2
+ bbox_image[:, sx:ex, sy-n:sy+n] = gray
+ bbox_image[:, sx:ex, ey-n:ey+n] = gray
+ bbox_image[:, sx-n:sx+n, sy:ey] = gray
+ bbox_image[:, ex-n:ex+n, sy:ey] = gray
+ return bbox_image
+
+def add_bbox_to_image(image, detected_bbox):
+ words = args.words
+
+ image = np.zeros_like(image) + 255
+ image = numpy_to_image(image)
+ for bbox in detected_bbox:
+ bbox = [int(x) for x in bbox[1:]]
+ # size = int((bbox[2] + bbox[3] - bbox[0] - bbox[0]) / 2)
+ size = 16
+ place = (int(bbox[1]/2 + bbox[3]/2), int(bbox[0]/2+bbox[2]/2))
+ image = add_text_to_img(image, words[bbox[-1]], size, '../files/ttf/simsun.ttf', (0,0,0), place)
+ return image
+
+def test_label(image_file, seg_file, bbox_file, save_folder):
+ if not os.path.exists(save_folder):
+ os.mkdir(save_folder)
+ image = Image.open(image_file).convert('RGB')
+ seg = Image.open(seg_file)
+ image.save(os.path.join(save_folder, '_image.png'))
+ seg.save(os.path.join(save_folder, '_seg.png'))
+
+ bbox_image = image_to_numpy(image)
+ bbox_label = json.load(open(bbox_file))
+ for bbox in bbox_label:
+ bbox_image = add_line(bbox_image, bbox)
+ image = numpy_to_image(bbox_image)
+ image.save(os.path.join(save_folder, '_bbox.png'))
+
+def generate_bbox_seg(image, font_place, font_size, font_list):
+ '''
+ Generate only the bounding-box position coordinates.
+ '''
+ imgh,imgw = image.size
+ font_num = len(font_list)
+
+ # generate the segmentation label
+ seg_label = np.zeros((3, image.size[1], image.size[0]), dtype=np.uint8) + 255
+ sy = font_place[0]
+ ey = sy + font_size * font_num
+ sx = font_place[1]
+ ex = sx + font_size
+ seg_label[:, sx:ex, sy:ey] = 128
+ # seg_label = seg_label.transpose((1,0,2))
+ # seg_label = Image.fromarray(seg_label)
+ seg_label = numpy_to_image(seg_label)
+
+ # generate the bbox label
+ bbox_label = []
+ for i, font in enumerate(font_list):
+ sx = font_place[0] + font_size * i
+ ex = sx + font_size
+ sy = font_place[1]
+ ey = sy + font_size
+ bbox_label.append([sy,sx,ey,ex,font])
+
+ # generate the bbox image
+ # bbox_image = np.zeros((3, image.size[0], image.size[1]), dtype=np.uint8) + 255
+ bbox_image = image_to_numpy(image)
+ for bbox in bbox_label:
+ bbox_image = add_line(bbox_image, bbox)
+ bbox_image = numpy_to_image(bbox_image)
+
+
+ return bbox_label, seg_label, bbox_image
+
+
+def generate_bbox_label(image, font_place, font_size, font_num, args, image_size):
+ '''
+ Generate supervision targets from the anchors.
+ '''
+ imgh,imgw = image.size
+ seg_label = np.zeros((int(image_size[0]/2), int(image_size[1]/2)), dtype=np.float32)
+ sx = float(font_place[0]) / image.size[0] * image_size[0]
+ ex = sx + float(font_size) / image.size[0] * image_size[0] * font_num
+ sy = float(font_place[1]) / image.size[1] * image_size[1]
+ ey = sy + float(font_size) / image.size[1] * image_size[1]
+ seg_label[int(sx/2):int(ex/2), int(sy/2):int(ey/2)] = 1
+ seg_label = seg_label.transpose((1,0))
+
+ bbox_label = np.zeros((
+ int(image_size[0]/args.stride), # 16
+ int(image_size[1]/args.stride), # 16
+ len(args.anchors), # 4
+ 4 # dx,dy,dd,c
+ ), dtype=np.float32)
+ fonts= []
+ for i in range(font_num):
+ x = font_place[0] + font_size/2. + i * font_size
+ y = font_place[1] + font_size/2.
+ h = font_size
+ w = font_size
+
+ x = float(x) * image_size[0] / imgh
+ h = float(h) * image_size[0] / imgh
+ y = float(y) * image_size[1] / imgw
+ w = float(w) * image_size[1] / imgw
+ fonts.append([x,y,h,w])
+
+ # print bbox_label.shape
+ for ix in range(bbox_label.shape[0]):
+ for iy in range(bbox_label.shape[1]):
+ for ia in range(bbox_label.shape[2]):
+ proposal = [ix*args.stride + args.stride/2, iy*args.stride + args.stride/2, args.anchors[ia]]
+ iou_fi = []
+ for fi, font in enumerate(fonts):
+ iou = comput_iou(font, proposal)
+ iou_fi.append((iou, fi))
+ max_iou, max_fi = sorted(iou_fi)[-1]
+ if max_iou > 0.5:
+ # positive sample
+ dx = (font[0] - proposal[0]) / float(proposal[2])
+ dy = (font[1] - proposal[1]) / float(proposal[2])
+ fd = max(font[2:])
+ dd = np.log(fd / float(proposal[2]))
+ # bbox_label[ix,iy,ia] = [dx, dy, dd, 1]
+ bbox_label[ix,iy,ia] = [dx, dy, dd, 1]
+ elif max_iou > 0.25:
+ # ignored
+ bbox_label[ix,iy,ia,3] = 0
+ else:
+ # negative sample
+ bbox_label[ix,iy,ia,3] = -1
+ # note the transpose operation here
+ bbox_label = bbox_label.transpose((1,0,2,3))
+
+
+ # return the computed anchor information
+ return bbox_label, seg_label
+
+def augment(image, seg, bbox, label):
+ return image, seg, bbox, label
+
+def random_select_indices(indices, n=10):
+ indices = np.array(indices)
+ # print('initial shape', indices.shape)
+ indices = indices.transpose(1,0)
+ # print('change shape', indices.shape)
+ np.random.shuffle(indices)
+ indices = indices[:n]
+ # print('select ', indices.shape)
+ indices = indices.transpose(1,0)
+ # print('change shape', indices.shape)
+ # indices = tuple(indices)
+ return tuple(indices)
+
+
+
+# test_label( '../../data/generated_images/1.png', '../../data/generated_images/1_seg.png', '../../data/generated_images/1_bbox.json', '../../data/test/')
diff --git a/code/prediction/loaddata/dataloader.py b/code/prediction/loaddata/dataloader.py
new file mode 100644
index 0000000..996ca62
--- /dev/null
+++ b/code/prediction/loaddata/dataloader.py
@@ -0,0 +1,189 @@
+# encoding: utf-8
+
+"""
+Read images and corresponding labels.
+"""
+
+import numpy as np
+import os
+import sys
+import json
+import torch
+from torch.utils.data import Dataset
+
+sys.path.append('loaddata')
+import data_function
+
+
+class DataSet(Dataset):
+ def __init__(self,
+ patient_list,
+ patient_time_record_dict,
+ patient_label_dict,
+ patient_master_dict,
+ phase='train', # phase
+ split_num=5, # split feature value into different parts
+ args=None # global arguments
+ ):
+
+ self.patient_list = patient_list
+ self.patient_time_record_dict = patient_time_record_dict
+ self.patient_label_dict = patient_label_dict
+ self.patient_master_dict = patient_master_dict
+ self.phase = phase
+ self.split_num = split_num
+ self.split_nor = args.split_nor
+ self.split_nn = args.split_nn
+ self.args = args
+ if args.task == 'task2':
+ self.length = 49
+ else:
+ self.length = 98
+
+
+ def get_visit_info(self, time_record_dict):
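+ # Convert one patient's {timestamp: records} dict into fixed-length arrays:
+ # code indices (visit), raw values, observation masks, timestamps and trend codes,
+ # zero-padded up to self.length (plus last_time for task2).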
+ # times = sorted([float(t) for t in time_record_dict.keys()])
+ times = sorted(time_record_dict.keys(), key=lambda s:float(s))
+ # for t in time_record_dict:
+ # time_record_dict[str(float(t))] = time_record_dict[t]
+ visit_list = []
+ value_list = []
+ mask_list = []
+ time_list = []
+
+ n_code = 72
+ import traceback
+
+ # trend
+ trend_list = []
+ previous_value = [[[],[]] for _ in range(143)]
+ change_th = 0.02
+ start_time = - self.args.avg_time * 2
+ end_time = -1
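+ # A variable's trend compares its current value with its recent average
+ # (observations within the last avg_time*2 hours): a change beyond change_th
+ # is encoded as decrease/increase, otherwise as stable.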
+
+ for time in times :
+ if float(time) <= -4 - self.length:
+ continue
+ if self.args.task == 'task2':
+ if float(time) > self.args.last_time:
+ continue
+ time = str(time)
+ records = time_record_dict[time]
+ feature_index = [r[0] for r in records]
+ feature_value = [float(r[1]) for r in records]
+
+ # embed feature value
+ feature_index = np.array(feature_index)
+ feature_value = np.array(feature_value)
+ feature = feature_index * self.split_nn + feature_value * self.split_num
+
+ # trend
+ trend = np.zeros(n_code, dtype=np.int64)
+ i_v = 0
+ for idx, val in zip(feature_index, feature_value):
+ # delete val with time less than start_time
+ ptimes = previous_value[idx][0]
+ lip = 0
+ for ip, pt in enumerate(ptimes):
+ if pt >= float(time) + start_time:
+ lip = ip
+ break
+
+ avg_val = None
+ if len(previous_value[idx][0]) == 1:
+ avg_val = previous_value[idx][1][-1]
+
+ previous_value[idx] = [
+ previous_value[idx][0][lip:],
+ previous_value[idx][1][lip:]]
+
+ # trend value
+ if len(previous_value[idx][0]):
+ avg_val = np.mean(previous_value[idx][1])
+ if avg_val is not None:
+ if val < avg_val - change_th:
+ delta = 0
+ elif val > avg_val + change_th:
+ delta = 1
+ else:
+ delta = 2
+ trend[i_v] = idx * 3 + delta + 1
+
+ # add new val
+ previous_value[idx][0].append(float(time))
+ previous_value[idx][1].append(float(val))
+
+ i_v += 1
+
+
+
+
+
+ visit = np.zeros(n_code, dtype=np.int64)
+ mask = np.zeros(n_code, dtype=np.int64)
+ i_v = 0
+ for feat, idx, val in zip(feature, feature_index, feature_value):
+
+ # order
+ mask[i_v] = 1
+ visit[i_v] = int(feat + 1)
+ i_v += 1
+
+
+
+
+ value = np.zeros((2, n_code ), dtype=np.int64)
+ value[0][: len(feature_index)] = feature_index + 1
+ value[1][: len(feature_index)] = (feature_value * 100).astype(np.int64)
+ value_list.append(value)
+
+ visit_list.append(visit)
+ mask_list.append(mask)
+ time_list.append(float(time))
+ trend_list.append(trend)
+
+ if self.args.task == 'task2':
+ num_len = self.length + self.args.last_time
+ # print 'task2', num_len, self.args.last_time
+ else:
+ num_len = self.length
+ # print 'task1'
+ # print 'num_len', num_len
+ # print len(visit_list)
+ assert len(visit_list) <= num_len
+ visit = np.zeros(n_code, dtype=np.int64)
+ trend = np.zeros(n_code, dtype=np.int64)
+ value = np.zeros((2, n_code), dtype=np.int64)
+ while len(visit_list) < num_len:
+ visit_list.append(visit)
+ value_list.append(value)
+ mask_list.append(visit)
+ time_list.append(0)
+ trend_list.append(trend)
+
+ return np.array(visit_list), np.array(value_list), np.array(mask_list, dtype=np.float32), np.array(time_list, dtype=np.float32), np.array(trend_list)
+
+
+
+
+ def __getitem__(self, index):
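+ # Returns (visit, value, mask, master, label, time, trend) for one patient;
+ # in the test phase the patient id is appended as well.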
+ patient = self.patient_list[index]
+ if self.args.use_visit:
+ visit_list, value_list, mask_list, time_list, trend_list= self.get_visit_info(self.patient_time_record_dict[patient])
+ master = self.patient_master_dict[patient]
+ master = [int(m) for m in master]
+ master = np.float32(master)
+ if self.args.final == 1:
+ label = np.float32(0)
+ else:
+ label = np.float32(self.patient_label_dict[patient])
+ if self.phase == 'test':
+ return visit_list, value_list, mask_list, master, label, time_list, trend_list, patient
+ else:
+ return visit_list, value_list, mask_list, master, label, time_list, trend_list
+
+
+
+
+ def __len__(self):
+ return len(self.patient_list)
diff --git a/code/prediction/loss.py b/code/prediction/loss.py
new file mode 100644
index 0000000..954df10
--- /dev/null
+++ b/code/prediction/loss.py
@@ -0,0 +1,97 @@
+# coding=utf8
+#########################################################################
+# File Name: classify_loss.py
+# Author: ccyin
+# mail: ccyin04@gmail.com
+# Created Time: Tuesday, June 11, 2019, 14:33:59
+#########################################################################
+
+import sys
+sys.path.append('classification/')
+
+# torch
+import numpy as np
+import torch
+import torchvision
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+
+
+class Loss(nn.Module):
+ def __init__(self):
+ super(Loss, self).__init__()
+ self.classify_loss = nn.BCELoss()
+ self.sigmoid = nn.Sigmoid()
+ self.regress_loss = nn.SmoothL1Loss()
+
+ def forward(self, font_output, font_target, use_hard_mining=False):
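+ # Binary cross-entropy split into positive and negative targets (each weighted 0.5),
+ # with optional online hard-example mining on both subsets.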
+ batch_size = font_output.size(0)
+ # font_output = font_output.cpu()
+ # font_target = font_target.cpu()
+
+
+
+ # font_target = font_target.unsqueeze(-1).expand(font_output.size()).contiguous()
+ font_output = self.sigmoid(font_output)
+ # font_loss = self.classify_loss(font_output, font_target)
+ # return [font_loss, font_loss, font_loss]
+
+
+ font_output = font_output.view(-1)
+ font_target = font_target.view(-1)
+ pos_index = font_target == 1
+ neg_index = font_target == 0
+
+ assert font_output.size() == font_target.size()
+ assert pos_index.size() == font_target.size()
+ assert neg_index.size() == font_target.size()
+
+ # print font_output.size(), font_target.size()
+
+
+ # pos
+ # print pos_index.dtype
+ # print pos_index.size()
+ # print pos_index
+ pos_target = font_target[pos_index]
+ pos_output = font_output[pos_index]
+ # pos_output = font_output.cpu()[pos_index.cpu()].cuda()
+ # pos_target = font_target.cpu()[pos_index.cpu()].cuda()
+ if use_hard_mining:
+ num_hard_pos = max(2, int(0.2 * batch_size))
+ if len(pos_output) > num_hard_pos:
+ pos_output, pos_target = hard_mining(pos_output, pos_target, num_hard_pos, largest=False, start=int(num_hard_pos/4))
+ if len(pos_output):
+ pos_loss = self.classify_loss(pos_output, pos_target) * 0.5
+ else:
+ pos_loss = 0
+
+
+ # neg
+ neg_output = font_output[neg_index]
+ neg_target = font_target[neg_index]
+ if use_hard_mining:
+ num_hard_neg = max(num_hard_pos, 2)
+ if len(neg_output) > num_hard_neg:
+ neg_output, neg_target = hard_mining(neg_output, neg_target, num_hard_neg, largest=True, start=int(num_hard_pos/4))
+ if len(neg_output):
+ neg_loss = self.classify_loss(neg_output, neg_target) * 0.5
+ else:
+ neg_loss = 0
+
+ font_loss = pos_loss + neg_loss
+ return [font_loss, pos_loss, neg_loss]
+ # return [font_loss.cuda(), pos_loss, neg_loss]
+
+
+def hard_mining(neg_output, neg_labels, num_hard, largest=True, start=0):
+ # num_hard = min(max(num_hard, 10), len(neg_output))
+ _, idcs = torch.topk(neg_output, min(num_hard, len(neg_output)), largest=largest)
+ start = 0
+ idcs = idcs[start:]
+ neg_output = torch.index_select(neg_output, 0, idcs)
+ neg_labels = torch.index_select(neg_labels, 0, idcs)
+ return neg_output, neg_labels
diff --git a/code/prediction/models/__init__.py b/code/prediction/models/__init__.py
new file mode 100644
index 0000000..81a3235
--- /dev/null
+++ b/code/prediction/models/__init__.py
@@ -0,0 +1,7 @@
+# coding=utf8
+#########################################################################
+# File Name: models/__init__.py
+# Author: ccyin
+# mail: ccyin04@gmail.com
+# Created Time: Mon 19 Aug 2019 02:16:51 AM CST
+#########################################################################
diff --git a/code/prediction/models/lstm.py b/code/prediction/models/lstm.py
new file mode 100644
index 0000000..f4574a7
--- /dev/null
+++ b/code/prediction/models/lstm.py
@@ -0,0 +1,232 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import json
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.autograd import *
+
+import numpy as np
+
+import sys
+sys.path.append('tools')
+import parse, py_op
+args = parse.args
+
+
+def time_encoding_data(d = 512, time = 200):
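+ # Sinusoidal time encoding: builds a (time, d) table of sin/cos features that is
+ # later loaded into an nn.Embedding and indexed by the (negated) timestamps.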
+ vec = np.array([np.arange(time) * i for i in range(d // 2)], dtype=np.float32).transpose()
+ vec = vec / vec.max() / 2
+ encoding = np.concatenate((np.sin(vec), np.cos(vec)), 1)
+ encoding = torch.from_numpy(encoding)
+ return encoding
+
+
+class LSTM(nn.Module):
+ def __init__(self, opt):
+ super ( LSTM, self ).__init__ ( )
+ self.use_cat = args.use_cat
+ self.avg_time = args.avg_time
+
+ self.embedding = nn.Embedding (opt.vocab_size, opt.embed_size )
+ self.lstm = nn.LSTM ( input_size=opt.embed_size,
+ hidden_size=opt.hidden_size,
+ num_layers=opt.num_layers,
+ batch_first=True,
+ bidirectional=True)
+
+ self.linear_embed = nn.Sequential (
+ nn.Linear ( opt.embed_size, opt.embed_size ),
+ nn.ReLU ( ),
+ nn.Linear ( opt.embed_size, opt.embed_size ),
+ )
+ self.tv_mapping = nn.Sequential (
+ nn.Linear ( opt.embed_size , opt.embed_size // 2),
+ nn.ReLU ( ),
+ nn.Dropout ( 0.25 ),
+ nn.Linear ( opt.embed_size // 2, opt.embed_size ),
+ )
+ self.alpha = nn.Linear(args.embed_size, 1)
+
+
+ no = 1
+ if self.use_cat:
+ no += 1
+ self.output_time = nn.Sequential (
+ nn.Linear(opt.embed_size * no, opt.embed_size),
+ nn.ReLU ( ),
+ )
+
+ time = 200
+ self.time_encoding = nn.Embedding.from_pretrained(time_encoding_data(opt.embed_size, time))
+ self.time_mapping = nn.Sequential (
+ nn.Linear ( opt.embed_size, opt.embed_size),
+ nn.ReLU ( ),
+ nn.Dropout ( 0.25 ),
+ nn.Linear ( opt.embed_size, opt.embed_size)
+ )
+
+ self.embed_linear = nn.Sequential (
+ nn.Linear ( opt.embed_size, opt.embed_size),
+ nn.ReLU ( ),
+ # nn.Dropout ( 0.25 ),
+ # nn.Linear ( opt.embed_size, opt.embed_size),
+ # nn.ReLU ( ),
+ nn.Dropout ( 0.25 ),
+ )
+ self.relu = nn.ReLU ( )
+
+ self.linears = nn.Sequential (
+ nn.Linear ( opt.hidden_size * 2, opt.rnn_size ),
+ # nn.ReLU ( ),
+ # nn.Dropout ( 0.25 ),
+ # nn.Linear ( opt.rnn_size, opt.rnn_size ),
+ nn.ReLU ( ),
+ nn.Dropout ( 0.25 ),
+ nn.Linear ( opt.rnn_size, 1),
+ )
+ mn = 128
+ self.master_linear = nn.Sequential (
+ nn.Linear ( 43, mn),
+ # nn.ReLU ( ),
+ # nn.Dropout ( 0.25 ),
+ # nn.Linear ( mn, mn),
+ nn.ReLU ( ),
+ nn.Dropout ( 0.25 ),
+ nn.Linear ( mn, 1),
+ )
+ self.output = nn.Sequential (
+ nn.Linear ( mn + opt.rnn_size , opt.rnn_size),
+ nn.ReLU ( ),
+ nn.Linear ( opt.rnn_size, mn),
+ nn.ReLU ( ),
+ nn.Dropout ( 0.25 ),
+ nn.Linear ( mn, 1),
+ )
+ self.pooling = nn.AdaptiveMaxPool1d(1)
+ self.opt = opt
+
+ def visit_pooling(self, x, mask, time, value=None, trend=None):
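+ # Pool the code embeddings inside each time step, either with global max pooling
+ # (use_glp) or with a learned softmax attention over codes (self.alpha), then
+ # combine the pooled vector with the time encoding (concatenation or addition).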
+
+
+
+ output = x
+ size = output.size()
+ output = output.view(size[0] * size[1], size[2], output.size(3)) # (bs*98, 72, 512)
+ if args.use_glp:
+ output = torch.transpose(output, 1,2).contiguous() # (bs*98, 512, 72)
+ output = self.pooling(output)
+ else:
+ weight = self.alpha(output) # (bs*98, 72, 1)
+ # print weight.size()
+ weight = weight.view(size[0]*size[1], size[2])
+ # print weight.size()
+ weight = F.softmax(weight, dim=1)
+ x = weight.data.cpu().numpy()
+ # print x.shape
+ weight = weight.view(size[0]*size[1], size[2], 1).expand(output.size())
+ output = weight * output # (bs*98, 512, 72)
+ # print output.size()
+ output = output.sum(1)
+ # print output.size()
+ # output = torch.transpose(output, 1,2).contiguous()
+ output = output.view(size[0], size[1], size[3])
+
+ # time encoding
+ time = - time.long()
+ time = self.time_encoding(time)
+ time = self.time_mapping(time)
+
+ if self.use_cat:
+ output = torch.cat((output, time), 2)
+ output = self.relu(output)
+ output = self.output_time(output)
+ else:
+ output = output + time
+ output = self.relu(output)
+
+
+
+ return output
+
+
+ def forward_2(self, x, master, mask=None, time=None, phase='train', value=None, trend=None):
+ '''
+ task2
+ '''
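+ # For task2, a risk score is produced at multiple prefixes of the LSTM output
+ # (max-pooled over the first i steps), so the model yields predictions at
+ # several observation cut-off times within one forward pass.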
+ size = list(x.size())
+ x = x.view(-1)
+ x = self.embedding( x )
+ x = self.embed_linear(x)
+ size.append(-1)
+ x = x.view(size)
+ if mask is not None:
+ x = self.visit_pooling(x, mask, time, value, trend)
+ lstm_out, _ = self.lstm( x )
+ lstm_out = torch.transpose(lstm_out, 1, 2).contiguous() # (bs, 512, 98)
+ mask = self.pooling(mask)
+ # print 'mask', mask.size()
+ pool_out = []
+ mask_out = []
+ time_out = []
+ time = time.data.cpu().numpy()
+ if phase == 'train':
+ start, delta = 4, 6
+ else:
+ start, delta = 1, 1
+ for i in range(start, lstm_out.size(2), delta):
+ pool_out.append(self.pooling(lstm_out[:,:, :i]))
+ mask_out.append(mask[:, i])
+ time_out.append(time[:, i])
+ pool_out.append(self.pooling(lstm_out))
+ mask_out.append(mask[:, 0])
+ time_out.append(np.zeros(size[0]) - 4)
+
+ lstm_out = torch.cat(pool_out, 2) # (bs, 512, 98)
+ mask_out = torch.cat(mask_out, 1) # (bs, 98)
+ time_out = np.array(time_out).transpose() # (bs, 98)
+
+ # print 'lstm_out', lstm_out.size()
+ # print 'mask_out', mask_out.size()
+ # print err
+
+ lstm_out = torch.transpose(lstm_out, 1, 2).contiguous() # (bs, 98, 512)
+
+ out_vital = self.linears(lstm_out)
+ size = list(out_vital.size())
+ out_vital = out_vital.view(size[:2])
+ out_master = self.master_linear(master).expand(size[:2])
+ out = out_vital + out_master
+ return out, mask_out, time_out
+
+ def forward_1(self, x, master, mask=None, time=None, phase='train', value=None, trend=None):
+ # out = self.master_linear(master)
+ size = list(x.size())
+ x = x.view(-1)
+ x = self.embedding( x )
+ # print x.size()
+ x = self.embed_linear(x)
+ size.append(-1)
+ x = x.view(size)
+ if mask is not None:
+ x = self.visit_pooling(x, mask, time, value, trend)
+ lstm_out, _ = self.lstm( x )
+
+ lstm_out = torch.transpose(lstm_out, 1, 2).contiguous()
+ lstm_out = self.pooling(lstm_out)
+ lstm_out = lstm_out.view(lstm_out.size(0), -1)
+
+ out = self.linears(lstm_out) + self.master_linear(master)
+ return out
+
+ def forward(self, x, master, mask=None, time=None, phase='train', value=None, trend=None):
+ if args.task == 'task2':
+ return self.forward_2(x, master, mask, time, phase, value, trend)
+ # return self.forward_1(x, master, mask, time, phase, value, trend)
+ else:
+ return self.forward_1(x, master, mask, time, phase, value, trend)
+
+
+
diff --git a/code/prediction/test.py b/code/prediction/test.py
new file mode 100644
index 0000000..3cbc8d0
--- /dev/null
+++ b/code/prediction/test.py
@@ -0,0 +1,19 @@
+import os
+import sys
+from tools import parse, py_op
+args = parse.args
+
+
+for seed in [1, 40, 500, 100, 2019]:
+ for task in ['case1', 'task1', 'task2']:
+ cmd = 'python main.py --phase test --final 0 --batch-size 8 --task {:s} --seed {:d} --resume ../data/models/{:s}-snm-{:d}-snr-{:d}-value-{:d}-trend-{:d}-cat-{:d}-lt-{:d}-size-{:d}-seed-{:d}-{:s}'.format(task, seed, task,
+ # cmd = 'python main.py --phase valid --batch-size 8 --task {:s} --seed {:d} --resume ../data/models/{:s}-snm-{:d}-snr-{:d}-value-{:d}-trend-{:d}-cat-{:d}-lt-{:d}-size-{:d}-seed-{:d}-{:s}'.format(task, seed, task,
+ args.split_num, args.split_nor, args.use_value, args.use_trend,
+ args.use_cat, args.last_time, args.embed_size, args.seed, 'best.ckpt')
+ print(cmd)
+ os.system(cmd)
+ print()
+ print()
+ print()
+ break
+
diff --git a/code/prediction/tools/__init__.py b/code/prediction/tools/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/code/prediction/tools/measures.py b/code/prediction/tools/measures.py
new file mode 100644
index 0000000..05e498d
--- /dev/null
+++ b/code/prediction/tools/measures.py
@@ -0,0 +1,172 @@
+# coding=utf8
+import os
+import numpy as np
+from sklearn import metrics
+from PIL import Image
+import traceback
+
+def stati_class_number_true_flase(label, pred):
+ label = np.array(label)
+ pred = np.array(pred)
+
+ cls_list = set(label) | set(pred)
+ d = dict()
+ for cls in cls_list:
+ d[cls] = dict()
+ d[cls]['number'] = np.sum(label==cls)
+ d[cls]['true'] = np.sum(label[label==cls]==pred[label==cls])
+ d[cls]['pred'] = np.sum(pred==cls)
+ return d
+
+def stati_class_number_true_flase_multi_label_margin(labels, preds):
+
+ d = dict()
+ for label, pred in zip(labels, preds):
+ label = set(label[label>=0])
+ for cls in range(len(pred)):
+ if cls not in d:
+ d[cls] = dict()
+ d[cls]['number'] = 0
+ d[cls]['true'] = 0
+ d[cls]['pred'] = 0
+ if cls in label:
+ d[cls]['number'] += 1
+ if pred[cls] > 0.5:
+ d[cls]['true'] += 1
+ if pred[cls] > 0.5:
+ d[cls]['pred'] += 1
+ return d
+
+def stati_class_number_true_flase_bce(labels, preds):
+ d = dict()
+ labels = labels.astype(np.int64).reshape(-1)
+ preds = preds.reshape(-1) > 0
+ index = labels >= 0
+ labels = labels[index]
+ preds = preds[index]
+
+ preds_num = preds.sum(0)
+ true_num = (labels+preds==2).sum(0)
+ for cls in range(2):
+ d[cls] = dict()
+ d[cls]['number'] = (labels==cls).sum()
+ d[cls]['true'] = (labels+preds==2*cls).sum()
+ d[cls]['pred'] = (preds==cls).sum()
+ return d
+
+def measures(d_list):
+ # merge the statistics from each prediction
+ d_all = dict()
+ for d in d_list:
+ for cls in d.keys():
+ if cls not in d_all:
+ d_all[cls] = dict()
+ for k in d[cls].keys():
+ if k not in d_all[cls]:
+ d_all[cls][k] = 0
+ d_all[cls][k] += d[cls][k]
+ m = dict()
+ number = sum([d_all[cls]['number'] for cls in d_all.keys()])
+ for cls in d_all:
+ m[cls] = dict()
+ m[cls]['number'] = d_all[cls]['number']
+ m[cls]['true'] = d_all[cls]['true']
+ m[cls]['pred'] = d_all[cls]['pred']
+ m[cls]['ratio'] = d_all[cls]['number'] / (float(number) + 10e-10)
+ m[cls]['accuracy'] = d_all[cls]['true'] / (float(d_all[cls]['number']) + 10e-10)
+ m[cls]['precision'] = d_all[cls]['true'] /(float(d_all[cls]['pred']) + 10e-10)
+ return m
+
+def print_measures(m, s = 'measures'):
+ print(s)
+ accuracy = 0
+ for cls in sorted(m.keys()):
+ print('\tclass: {:d}\taccuracy:{:.6f}\tprecision:{:.6f}\tratio:{:.6f}\t\tN/T/P:{:d}/{:d}/{:d}'.format(
+ cls, m[cls]['accuracy'], m[cls]['precision'], m[cls]['ratio'], m[cls]['number'], m[cls]['true'], m[cls]['pred']))
+ accuracy += m[cls]['accuracy'] * m[cls]['ratio']
+ print('\tacc:{:.6f}'.format(accuracy))
+ return accuracy
+
+def mse(pred_image, image):
+ pred_image = pred_image.reshape(-1).astype(np.float32)
+ image = image.reshape(-1).astype(np.float32)
+ mse_err = metrics.mean_squared_error(pred_image,image)
+ return mse_err
+
+def psnr(pred_image, image):
+ return 10 * np.log10(255*255/mse(pred_image,image))
+
+
+def psnr_pred(stain_vis=20, end= 10000):
+ clean_dir = '../../data/AI/testB/'
+ psnr_list = []
+ f = open('../../data/result.csv','w')
+ for i,clean in enumerate(os.listdir(clean_dir)):
+ clean = os.path.join(clean_dir, clean)
+ clean_file = clean
+ pred = clean.replace('.jpg','.png').replace('data','data/test_clean')
+ stain = clean.replace('trainB','trainA').replace('testB','testA').replace('.jpg','_.jpg')
+
+ try:
+ pred = np.array(Image.open(pred).resize((250,250))).astype(np.float32)
+ clean = np.array(Image.open(clean).resize((250,250))).astype(np.float32)
+ stain = np.array(Image.open(stain).resize((250,250))).astype(np.float32)
+
+ # diff = np.abs(stain - pred)
+ # vis = 20
+ # pred[diffgray_vis] = stain[stain>gray_vis]
+
+ if end < 1000:
+ diff = np.abs(clean - stain)
+ # stain[diff>stain_vis] = pred[diff>stain_vis]
+ stain[diff>stain_vis] = clean[diff>stain_vis]
+
+ psnr_pred = psnr(clean, pred)
+ psnr_stain = psnr(clean, stain)
+ psnr_list.append([psnr_stain, psnr_pred])
+ except:
+ continue
+ if i>end:
+ break
+ print(i, min(end, 1000))
+
+ f.write(clean_file.split('/')[-1].split('.')[0])
+ f.write(',')
+ f.write(str(psnr_stain))
+ f.write(',')
+ f.write(str(psnr_pred))
+ f.write(',')
+ f.write(str(psnr_pred/psnr_stain - 1))
+ f.write('\n')
+ # print('prediction', np.mean(psnr_list))
+ psnr_list = np.array(psnr_list)
+ psnr_mean = ((psnr_list[:,1] - psnr_list[:,0]) / psnr_list[:,0]).mean()
+ if end > 1000:
+ print('stained-image PSNR', psnr_list[:,0].mean())
+ print('predicted-image PSNR', psnr_list[:,1].mean())
+ print('gain ratio', psnr_mean)
+ f.write(str(psnr_mean))
+ f.close()
+ return psnr_list[:,0].mean()
+
+def main():
+ pmax = [0.,0.]
+ for vis in range(1, 30):
+ p = psnr_pred(vis, 10)
+ print(vis, p)
+ if p > pmax[1]:
+ pmax = [vis, p]
+ print('...')
+ # print(256, psnr_pred(256))
+ print(pmax)
+ # print 10 * np.log10(255*255/metrics.mean_squared_error([3],[9]))
+
+
+if __name__ == '__main__':
+ psnr_pred(4000)
+ # main()
+ # for v in range(1,10):
+ # print v, 10 * np.log10(255*255/v/v)
diff --git a/code/prediction/tools/parse.py b/code/prediction/tools/parse.py
new file mode 100644
index 0000000..d9250fa
--- /dev/null
+++ b/code/prediction/tools/parse.py
@@ -0,0 +1,233 @@
+# coding=utf8
+
+import argparse
+
+parser = argparse.ArgumentParser(description='DII Challenge 2019')
+
+parser.add_argument(
+ '--data-dir',
+ type=str,
+ default='/home/yin/data/',
+ help='data directory'
+ )
+parser.add_argument(
+ '--result-dir',
+ type=str,
+ default='../result/',
+ help='result directory'
+ )
+parser.add_argument(
+ '--file-dir',
+ type=str,
+ default='../file/',
+ help='useful file directory'
+ )
+parser.add_argument(
+ '--vital-file',
+ type=str,
+ default='../file/vital.csv',
+ help='vital information'
+ )
+parser.add_argument(
+ '--master-file',
+ type=str,
+ default='../file/master.csv',
+ help='master information'
+ )
+parser.add_argument(
+ '--label-file',
+ type=str,
+ default='../file/label.csv',
+ help='label'
+ )
+parser.add_argument(
+ '--model',
+ '-m',
+ type=str,
+ default='lstm',
+ help='model'
+ )
+parser.add_argument(
+ '--embed-size',
+ metavar='EMBED SIZE',
+ type=int,
+ default=256,
+ help='embed size'
+ )
+parser.add_argument(
+ '--rnn-size',
+ metavar='rnn SIZE',
+ type=int,
+ help='rnn size'
+ )
+parser.add_argument(
+ '--hidden-size',
+ metavar='hidden SIZE',
+ type=int,
+ help='hidden size'
+ )
+parser.add_argument(
+ '--split-num',
+ metavar='split num',
+ type=int,
+ default=5,
+ help='split num'
+ )
+parser.add_argument(
+ '--split-nor',
+ metavar='split normal range',
+ type=int,
+ default=3,
+ help='split num'
+ )
+parser.add_argument(
+ '--num-layers',
+ metavar='num layers',
+ type=int,
+ default=2,
+ help='num layers'
+ )
+parser.add_argument(
+ '--num-code',
+ metavar='num codes',
+ type=int,
+ default=1200,
+ help='num code'
+ )
+parser.add_argument(
+ '--use-glp',
+ metavar='use global pooling operation',
+ type=int,
+ default=1,
+ help='use global pooling operation'
+ )
+parser.add_argument(
+ '--use-visit',
+ metavar='use visit as input',
+ type=int,
+ default=1,
+ help='use visit as input'
+ )
+parser.add_argument(
+ '--use-value',
+ metavar='use value embedding as input',
+ type=int,
+ default=1,
+ help='use value embedding as input'
+ )
+parser.add_argument(
+ '--use-cat',
+ metavar='use cat for time and value embedding',
+ type=int,
+ default=1,
+ help='use cat or add'
+ )
+parser.add_argument(
+ '--use-trend',
+ metavar='use feature variation trend',
+ type=int,
+ default=1,
+ help='use trend'
+ )
+parser.add_argument(
+ '--avg-time',
+ metavar='avg time for trend, hours',
+ type=int,
+ default=4,
+ help='avg time for trend'
+ )
+parser.add_argument(
+ '--seed',
+ metavar='seed',
+ type=int,
+ default=1,
+ help='seed'
+ )
+parser.add_argument(
+ '--set',
+ metavar='split set for training',
+ type=int,
+ default=0,
+ help='split set'
+ )
+parser.add_argument(
+ '--last-time',
+ metavar='last-time for task2',
+ type=int,
+ default=-4,
+ help='last time'
+ )
+parser.add_argument(
+ '--final',
+ metavar='final test to submit',
+ type=int,
+ default=0,
+ help='final'
+ )
+
+
+
+
+parser.add_argument('--phase',
+ default='train',
+ type=str,
+ metavar='S',
+ help='pretrain/train/test phase')
+parser.add_argument(
+ '--batch-size',
+ '-b',
+ metavar='BATCH SIZE',
+ type=int,
+ default=32,
+ help='batch size'
+ )
+parser.add_argument('--save-dir',
+ default='../../data',
+ type=str,
+ metavar='S',
+ help='save dir')
+parser.add_argument('--resume',
+ default='',
+ type=str,
+ metavar='S',
+ help='start from checkpoints')
+parser.add_argument('--task',
+ default='task1',
+ type=str,
+ metavar='S',
+ help='prediction task (case1/task1/task2)')
+
+#####
+parser.add_argument('-j',
+ '--workers',
+ default=8,
+ type=int,
+ metavar='N',
+ help='number of data loading workers (default: 32)')
+parser.add_argument('--lr',
+ '--learning-rate',
+ default=0.0001,
+ type=float,
+ metavar='LR',
+ help='initial learning rate')
+parser.add_argument('--epochs',
+ default=20,
+ type=int,
+ metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--save-freq',
+ default='5',
+ type=int,
+ metavar='S',
+ help='save frequency')
+parser.add_argument('--save-pred-freq',
+ default='10',
+ type=int,
+ metavar='S',
+ help='save pred clean frequency')
+parser.add_argument('--val-freq',
+ default='5',
+ type=int,
+ metavar='S',
+ help='val frequency')
+args = parser.parse_args()
diff --git a/code/prediction/tools/plot.py b/code/prediction/tools/plot.py
new file mode 100644
index 0000000..9f36c53
--- /dev/null
+++ b/code/prediction/tools/plot.py
@@ -0,0 +1,31 @@
+# coding=utf8
+import matplotlib.pyplot as plt
+import numpy as np
+
+def plot_multi_graph(image_list, name_list, save_path=None, show=False):
+ graph_place = int(np.sqrt(len(name_list) - 1)) + 1
+ for i, (image, name) in enumerate(zip(image_list, name_list)):
+ ax1 = plt.subplot(graph_place,graph_place,i+1)
+ ax1.set_title(name)
+ # plt.imshow(image,cmap='gray')
+ plt.imshow(image)
+ plt.axis('off')
+ if save_path:
+ plt.savefig(save_path)
+ pass
+ if show:
+ plt.show()
+
+def plot_multi_line(x_list, y_list, name_list, save_path=None, show=False):
+ graph_place = int(np.sqrt(len(name_list) - 1)) + 1
+ for i, (x, y, name) in enumerate(zip(x_list, y_list, name_list)):
+ ax1 = plt.subplot(graph_place,graph_place,i+1)
+ ax1.set_title(name)
+ plt.plot(x,y)
+ # plt.imshow(image,cmap='gray')
+ if save_path:
+ plt.savefig(save_path)
+ if show:
+ plt.show()
+
+
diff --git a/code/prediction/tools/py_op.py b/code/prediction/tools/py_op.py
new file mode 100644
index 0000000..0ad90a9
--- /dev/null
+++ b/code/prediction/tools/py_op.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+"""
+Utility functions for common Python operations.
+"""
+import os
+import json
+import traceback
+from collections import OrderedDict
+import random
+
+import sys
+# reload(sys)
+# sys.setdefaultencoding('utf-8')
+
+# fuzzywuzzy is only needed by the fuzzy-matching helpers below
+try:
+ from fuzzywuzzy import fuzz
+except ImportError:
+ fuzz = None
+
+################################################################################
+### pre define variables
+#:: enumerate
+#:: raw_input
+#:: listdir
+#:: sorted
+### pre define function
+def mywritejson(save_path,content):
+ content = json.dumps(content,indent=4,ensure_ascii=False)
+ with open(save_path,'w') as f:
+ f.write(content)
+
+def myreadjson(load_path):
+ with open(load_path,'r') as f:
+ return json.loads(f.read())
+
+def mywritefile(save_path,content):
+ with open(save_path,'w') as f:
+ f.write(content)
+
+def myreadfile(load_path):
+ with open(load_path,'r') as f:
+ return f.read()
+
+def myprint(content):
+ print(json.dumps(content,indent=4,ensure_ascii=False))
+
+def rm(fi):
+ os.system('rm ' + fi)
+
+def mystrip(s):
+ return ''.join(s.split())
+
+def mysorteddict(d,key = lambda s:s, reverse=False):
+ dordered = OrderedDict()
+ for k in sorted(d.keys(),key = key,reverse=reverse):
+ dordered[k] = d[k]
+ return dordered
+
+def mysorteddictfile(src,obj):
+ mywritejson(obj,mysorteddict(myreadjson(src)))
+
+def myfuzzymatch(srcs,objs,grade=80):
+ matchDict = OrderedDict()
+ for src in srcs:
+ for obj in objs:
+ value = fuzz.partial_ratio(src,obj)
+ if value > grade:
+ try:
+ matchDict[src].append(obj)
+ except:
+ matchDict[src] = [obj]
+ return matchDict
+
+def mydumps(x):
+ return json.dumps(x,indent=4,ensure_ascii=False)
+
+def get_random_list(l,num=-1,isunique=0):
+ if isunique:
+ l = set(l)
+ if num < 0:
+ num = len(l)
+ if isunique and num > len(l):
+ return
+ lnew = []
+ l = list(l)
+ while(num>len(lnew)):
+ x = l[int(random.random()*len(l))]
+ if isunique and x in lnew:
+ continue
+ lnew.append(x)
+ return lnew
+
+def fuzz_list(node1_list,node2_list,score_baseline=66,proposal_num=10,string_map=None):
+ node_dict = { }
+ for i,node1 in enumerate(node1_list):
+ match_score_dict = { }
+ for node2 in node2_list:
+ if node1 != node2:
+ if string_map is not None:
+ n1 = string_map(node1)
+ n2 = string_map(node2)
+ score = fuzz.partial_ratio(n1,n2)
+ if n1 == n2:
+ node2_list.remove(node2)
+ else:
+ score = fuzz.partial_ratio(node1,node2)
+ if score > score_baseline:
+ match_score_dict[node2] = score
+ else:
+ node2_list.remove(node2)
+ node2_sort = sorted(match_score_dict.keys(), key=lambda k:match_score_dict[k],reverse=True)
+ node_dict[node1] = [[n,match_score_dict[n]] for n in node2_sort[:proposal_num]]
+ print(i,len(node1_list))
+ return node_dict, node2_list
+
+def swap(a,b):
+ return b, a
+
+def mkdir(d):
+ path = d.split('/')
+ for i in range(len(path)):
+ d = '/'.join(path[:i+1])
+ if not os.path.exists(d):
+ os.mkdir(d)
+
diff --git a/code/prediction/tools/segmentation.py b/code/prediction/tools/segmentation.py
new file mode 100644
index 0000000..52119e1
--- /dev/null
+++ b/code/prediction/tools/segmentation.py
@@ -0,0 +1,153 @@
+# coding=utf8
+import matplotlib.pyplot as plt
+from scipy import ndimage as ndi
+from skimage import morphology,color,data
+from skimage import filters
+import numpy as np
+import skimage
+import os
+from skimage import measure
+
+
+
+def watershed(image, label=None):
+ denoised = filters.rank.median(image, morphology.disk(2)) # filter noise
+ # use pixels whose gradient is below 10 as the initial markers
+ markers = filters.rank.gradient(denoised, morphology.disk(5)) < 10
+ markers = ndi.label(markers)[0]
+
+ gradient = filters.rank.gradient(denoised, morphology.disk(2)) # compute the gradient
+ labels = morphology.watershed(gradient, markers, mask=image) # gradient-based watershed
+
+ fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(6, 6))
+ axes = axes.ravel()
+ ax0, ax1, ax2, ax3 = axes
+
+ ax0.imshow(image, cmap=plt.cm.gray, interpolation='nearest')
+ ax0.set_title("Original")
+ # ax1.imshow(gradient, cmap=plt.cm.spectral, interpolation='nearest')
+ ax1.imshow(gradient, cmap=plt.cm.gray, interpolation='nearest')
+ ax1.set_title("Gradient")
+ if label is not None:
+ # ax2.imshow(markers, cmap=plt.cm.spectral, interpolation='nearest')
+ ax2.imshow(label, cmap=plt.cm.gray, interpolation='nearest')
+ else:
+ ax2.imshow(markers, cmap=plt.cm.spectral, interpolation='nearest')
+ ax2.set_title("Markers")
+ ax3.imshow(labels, cmap=plt.cm.spectral, interpolation='nearest')
+ ax3.set_title("Segmented")
+
+ for ax in axes:
+ ax.axis('off')
+
+ fig.tight_layout()
+ plt.show()
+
+def plot_4(image, gradient,label,segmentation, save_path=None):
+ fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(6, 6))
+ axes = axes.ravel()
+ ax0, ax1, ax2, ax3 = axes
+ ax0.imshow(image, cmap=plt.cm.gray, interpolation='nearest')
+ ax0.set_title("Original")
+ ax1.imshow(gradient, cmap=plt.cm.gray, interpolation='nearest')
+ ax1.set_title("Gradient")
+ ax2.imshow(label, cmap=plt.cm.gray, interpolation='nearest')
+ ax2.set_title("label")
+ ax3.imshow(segmentation, cmap=plt.cm.spectral, interpolation='nearest')
+ ax3.set_title("Segmented")
+
+ for ax in axes:
+ ax.axis('off')
+
+ fig.tight_layout()
+ if save_path:
+ print(save_path)
+ plt.savefig(save_path)
+ else:
+ plt.show()
+
+def fill(image):
+ '''
+ Fill the holes inside the image.
+ A temporary helper; consider replacing it later.
+ '''
+ label_img = measure.label(image, background=1)
+ props = measure.regionprops(label_img)
+ max_area = np.array([p.area for p in props]).max()
+ for i,prop in enumerate(props):
+ if prop.area < max_area:
+ image[prop.coords[:,0],prop.coords[:,1]] = 1
+ return image
+
+
+
+def my_watershed(image, label=None, min_gray=480, max_gray=708, min_gradient=5, show=False, save_path='/tmp/x.jpg'):
+ image = image - min_gray
+ image[image>max_gray] = 0
+ image[image< 10] = 0
+ image = image * 5
+
+ denoised = filters.rank.median(image, morphology.disk(2)) # filter noise
+ # use pixels whose gradient is below 10 as the initial markers
+ markers = filters.rank.gradient(denoised, morphology.disk(5)) < 10
+ markers = ndi.label(markers)[0]
+
+ gradient = filters.rank.gradient(denoised, morphology.disk(2)) # compute the gradient
+ labels = gradient > min_gradient
+
+ mask = gradient > min_gradient
+ label_img = measure.label(mask, background=0)
+ props = measure.regionprops(label_img)
+ pred = np.zeros_like(gradient)
+ for i,prop in enumerate(props):
+ if prop.area > 50:
+ region = np.array(prop.coords)
+ vx,vy = region.var(0)
+ v = vx + vy
+ if v < 200:
+ pred[prop.coords[:,0],prop.coords[:,1]] = 1
+
+ # fill the holes enclosed by the detected edges
+ pred = fill(pred)
+
+ if show:
+ plot_4(image, gradient, label, pred)
+ else:
+ plot_4(image, gradient, label, pred, save_path)
+
+ return pred
+
+def segmentation(image_npy, label_npy,save_path):
+ print(image_npy)
+ image = np.load(image_npy)
+ label = np.load(label_npy)
+ if np.sum(label) == 0:
+ return
+ min_gray,max_gray = 480, 708
+ my_watershed(image,label,min_gray, max_gray,show=False, save_path=save_path)
+
+def main():
+ data_dir = '/home/yin/all/PVL_DATA/preprocessed/2D/'
+ save_dir = '/home/yin/all/PVL_DATA/tool_result/'
+ os.system('rm -r ' + save_dir)
+ os.system('mkdir ' + save_dir)
+ for patient in os.listdir(data_dir):
+ patient_dir = os.path.join(data_dir, patient)
+ for f in os.listdir(patient_dir):
+ if 'roi.npy' in f:
+ label_npy = os.path.join(patient_dir,f)
+ image_npy = label_npy.replace('.roi.npy','.npy')
+ segmentation(image_npy,label_npy, os.path.join(save_dir,label_npy.strip('/').replace('/','.').replace('npy','jpg')))
+
+if __name__ == '__main__':
+ # image =color.rgb2gray(data.camera())
+ # watershed(image)
+ main()
+ image_npy = '/home/yin/all/PVL_DATA/preprocessed/2D/JD_chen_xi/23.npy'
+ image_npy = '/home/yin/all/PVL_DATA/preprocessed/2D/JD_chen_xi/14.npy'
+ image_npy = '/home/yin/all/PVL_DATA/preprocessed/2D/JD_zhang_yu_chen/23.npy'
+ label_npy = image_npy.replace('.npy','.roi.npy')
+ segmentation(image_npy,label_npy)
+
+
diff --git a/code/prediction/tools/utils.py b/code/prediction/tools/utils.py
new file mode 100644
index 0000000..dfbd5aa
--- /dev/null
+++ b/code/prediction/tools/utils.py
@@ -0,0 +1,368 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2017 www.drcubic.com, Inc. All Rights Reserved
+#
+"""
+File: utils.py
+Author: shileicao(shileicao@stu.xjtu.edu.cn)
+Date: 2017-06-20 14:56:54
+
+**Note.** This code absorb some code from following source.
+1. [DSB2017](https://github.com/lfz/DSB2017)
+"""
+
+import os
+import sys
+
+import numpy as np
+import torch
+
+
+def getFreeId():
+ import pynvml
+
+ pynvml.nvmlInit()
+
+ def getFreeRatio(id):
+ handle = pynvml.nvmlDeviceGetHandleByIndex(id)
+ use = pynvml.nvmlDeviceGetUtilizationRates(handle)
+ ratio = 0.5 * (float(use.gpu + float(use.memory)))
+ return ratio
+
+ deviceCount = pynvml.nvmlDeviceGetCount()
+ available = []
+ for i in range(deviceCount):
+ if getFreeRatio(i) < 70:
+ available.append(i)
+ gpus = ''
+ for g in available:
+ gpus = gpus + str(g) + ','
+ gpus = gpus[:-1]
+ return gpus
+
+
+def setgpu(gpuinput):
+ freeids = getFreeId()
+ if gpuinput == 'all':
+ gpus = freeids
+ else:
+ gpus = gpuinput
+ busy_gpu = [g not in freeids for g in gpus.split(',')]
+ if any(busy_gpu):
+ raise ValueError('gpu' + ' '.join(busy_gpu) + 'is being used')
+ print('using gpu ' + gpus)
+ os.environ['CUDA_VISIBLE_DEVICES'] = gpus
+ return len(gpus.split(','))
+
+
+def error_mask_stats(labels, filenames):
+ error_f = []
+ for i, f in enumerate(filenames):
+# if not np.all(labels[i] > 0):
+# error_f.append(f)
+ for bbox_i in range(labels[i].shape[0]):
+ imgs = np.load(f)
+ if not np.all(
+ np.array(imgs.shape[1:]) - labels[i][bbox_i][:-1] > 0):
+ error_f.append(f)
+ error_f = list(set(error_f))
+ fileid_list = [os.path.split(filename)[1].split('_')[0]
+ for filename in error_f]
+ print("','".join(fileid_list))
+ return error_f
+
+
+class Logger(object):
+ def __init__(self, logfile):
+ self.terminal = sys.stdout
+ self.log = open(logfile, "a")
+
+ def write(self, message):
+ self.terminal.write(message)
+ self.log.write(message)
+
+ def flush(self):
+ #this flush method is needed for python 3 compatibility.
+ #this handles the flush command by doing nothing.
+ #you might want to specify some extra behavior here.
+ pass
+
+
+def split4(data, max_stride, margin):
+ splits = []
+ data = torch.Tensor.numpy(data)
+ _, c, z, h, w = data.shape
+
+ w_width = np.ceil(float(w / 2 + margin) /
+ max_stride).astype('int') * max_stride
+ h_width = np.ceil(float(h / 2 + margin) /
+ max_stride).astype('int') * max_stride
+ pad = int(np.ceil(float(z) / max_stride) * max_stride) - z
+ leftpad = pad / 2
+ pad = [[0, 0], [0, 0], [leftpad, pad - leftpad], [0, 0], [0, 0]]
+ data = np.pad(data, pad, 'constant', constant_values=-1)
+ data = torch.from_numpy(data)
+ splits.append(data[:, :, :, :h_width, :w_width])
+ splits.append(data[:, :, :, :h_width, -w_width:])
+ splits.append(data[:, :, :, -h_width:, :w_width])
+ splits.append(data[:, :, :, -h_width:, -w_width:])
+
+ return torch.cat(splits, 0)
+
+
+def combine4(output, h, w):
+ splits = []
+ for i in range(len(output)):
+ splits.append(output[i])
+
+ output = np.zeros(
+ (splits[0].shape[0], h, w, splits[0].shape[3],
+ splits[0].shape[4]), np.float32)
+
+ h0 = output.shape[1] / 2
+ h1 = output.shape[1] - h0
+ w0 = output.shape[2] / 2
+ w1 = output.shape[2] - w0
+
+ splits[0] = splits[0][:, :h0, :w0, :, :]
+ output[:, :h0, :w0, :, :] = splits[0]
+
+ splits[1] = splits[1][:, :h0, -w1:, :, :]
+ output[:, :h0, -w1:, :, :] = splits[1]
+
+ splits[2] = splits[2][:, -h1:, :w0, :, :]
+ output[:, -h1:, :w0, :, :] = splits[2]
+
+ splits[3] = splits[3][:, -h1:, -w1:, :, :]
+ output[:, -h1:, -w1:, :, :] = splits[3]
+
+ return output
+
+
+def split8(data, max_stride, margin):
+ splits = []
+ if isinstance(data, np.ndarray):
+ c, z, h, w = data.shape
+ else:
+ _, c, z, h, w = data.size()
+
+ z_width = np.ceil(float(z / 2 + margin) /
+ max_stride).astype('int') * max_stride
+ w_width = np.ceil(float(w / 2 + margin) /
+ max_stride).astype('int') * max_stride
+ h_width = np.ceil(float(h / 2 + margin) /
+ max_stride).astype('int') * max_stride
+ for zz in [[0, z_width], [-z_width, None]]:
+ for hh in [[0, h_width], [-h_width, None]]:
+ for ww in [[0, w_width], [-w_width, None]]:
+ if isinstance(data, np.ndarray):
+ splits.append(data[np.newaxis, :, zz[0]:zz[1], hh[0]:hh[1],
+ ww[0]:ww[1]])
+ else:
+ splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:
+ ww[1]])
+
+ if isinstance(data, np.ndarray):
+ return np.concatenate(splits, 0)
+ else:
+ return torch.cat(splits, 0)
+
+
+def combine8(output, z, h, w):
+ splits = []
+ for i in range(len(output)):
+ splits.append(output[i])
+
+ output = np.zeros(
+ (z, h, w, splits[0].shape[3], splits[0].shape[4]), np.float32)
+
+ z_width = z / 2
+ h_width = h / 2
+ w_width = w / 2
+ i = 0
+ for zz in [[0, z_width], [z_width - z, None]]:
+ for hh in [[0, h_width], [h_width - h, None]]:
+ for ww in [[0, w_width], [w_width - w, None]]:
+ output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[
+ i][zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :]
+ i = i + 1
+
+ return output
+
+
+def split16(data, max_stride, margin):
+ splits = []
+ _, c, z, h, w = data.size()
+
+ z_width = np.ceil(float(z / 4 + margin) /
+ max_stride).astype('int') * max_stride
+ z_pos = [z * 3 / 8 - z_width / 2, z * 5 / 8 - z_width / 2]
+ h_width = np.ceil(float(h / 2 + margin) /
+ max_stride).astype('int') * max_stride
+ w_width = np.ceil(float(w / 2 + margin) /
+ max_stride).astype('int') * max_stride
+ for zz in [[0, z_width], [z_pos[0], z_pos[0] + z_width],
+ [z_pos[1], z_pos[1] + z_width], [-z_width, None]]:
+ for hh in [[0, h_width], [-h_width, None]]:
+ for ww in [[0, w_width], [-w_width, None]]:
+ splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[
+ 1]])
+
+ return torch.cat(splits, 0)
+
+
+def combine16(output, z, h, w):
+ splits = []
+ for i in range(len(output)):
+ splits.append(output[i])
+
+ output = np.zeros(
+ (z, h, w, splits[0].shape[3], splits[0].shape[4]), np.float32)
+
+ z_width = z / 4
+ h_width = h / 2
+ w_width = w / 2
+ splitzstart = splits[0].shape[0] / 2 - z_width / 2
+ z_pos = [z * 3 / 8 - z_width / 2, z * 5 / 8 - z_width / 2]
+ i = 0
+ for zz, zz2 in zip(
+ [[0, z_width], [z_width, z_width * 2], [z_width * 2, z_width * 3],
+ [z_width * 3 - z, None]],
+ [[0, z_width], [splitzstart, z_width + splitzstart],
+ [splitzstart, z_width + splitzstart], [z_width * 3 - z, None]]):
+ for hh in [[0, h_width], [h_width - h, None]]:
+ for ww in [[0, w_width], [w_width - w, None]]:
+ output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[
+ i][zz2[0]:zz2[1], hh[0]:hh[1], ww[0]:ww[1], :, :]
+ i = i + 1
+
+ return output
+
+
+def split32(data, max_stride, margin):
+ splits = []
+ _, c, z, h, w = data.size()
+
+ z_width = np.ceil(float(z / 2 + margin) /
+ max_stride).astype('int') * max_stride
+ w_width = np.ceil(float(w / 4 + margin) /
+ max_stride).astype('int') * max_stride
+ h_width = np.ceil(float(h / 4 + margin) /
+ max_stride).astype('int') * max_stride
+
+ w_pos = [w * 3 / 8 - w_width / 2, w * 5 / 8 - w_width / 2]
+ h_pos = [h * 3 / 8 - h_width / 2, h * 5 / 8 - h_width / 2]
+
+ for zz in [[0, z_width], [-z_width, None]]:
+ for hh in [[0, h_width], [h_pos[0], h_pos[0] + h_width],
+ [h_pos[1], h_pos[1] + h_width], [-h_width, None]]:
+ for ww in [[0, w_width], [w_pos[0], w_pos[0] + w_width],
+ [w_pos[1], w_pos[1] + w_width], [-w_width, None]]:
+ splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[
+ 1]])
+
+ return torch.cat(splits, 0)
+
+
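+# Recombine the 32 block outputs, cropping interior h/w tiles around their centers (splithstart, splitwstart).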
+def combine32(splits, z, h, w):
+
+ output = np.zeros(
+ (z, h, w, splits[0].shape[3], splits[0].shape[4]), np.float32)
+
+ z_width = int(np.ceil(float(z) / 2))
+ h_width = int(np.ceil(float(h) / 4))
+ w_width = int(np.ceil(float(w) / 4))
+ splithstart = splits[0].shape[1] // 2 - h_width // 2
+ splitwstart = splits[0].shape[2] // 2 - w_width // 2
+
+ i = 0
+ for zz in [[0, z_width], [z_width - z, None]]:
+
+ for hh, hh2 in zip(
+ [[0, h_width], [h_width, h_width * 2], [h_width * 2, h_width * 3],
+ [h_width * 3 - h, None]],
+ [[0, h_width], [splithstart, h_width + splithstart],
+ [splithstart, h_width + splithstart], [h_width * 3 - h, None]]):
+
+ for ww, ww2 in zip(
+ [[0, w_width], [w_width, w_width * 2],
+ [w_width * 2, w_width * 3], [w_width * 3 - w, None]],
+ [[0, w_width], [splitwstart, w_width + splitwstart],
+ [splitwstart, w_width + splitwstart],
+ [w_width * 3 - w, None]]):
+
+ output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[
+ i][zz[0]:zz[1], hh2[0]:hh2[1], ww2[0]:ww2[1], :, :]
+ i = i + 1
+
+ return output
+
+
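+# Split into 64 overlapping blocks: 4 along each of z, h and w.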
+def split64(data, max_stride, margin):
+ splits = []
+ _, c, z, h, w = data.size()
+
+ z_width = np.ceil(float(z / 4 + margin) /
+ max_stride).astype('int') * max_stride
+ w_width = np.ceil(float(w / 4 + margin) /
+ max_stride).astype('int') * max_stride
+ h_width = np.ceil(float(h / 4 + margin) /
+ max_stride).astype('int') * max_stride
+
+ z_pos = [z * 3 // 8 - z_width // 2, z * 5 // 8 - z_width // 2]
+ w_pos = [w * 3 // 8 - w_width // 2, w * 5 // 8 - w_width // 2]
+ h_pos = [h * 3 // 8 - h_width // 2, h * 5 // 8 - h_width // 2]
+
+ for zz in [[0, z_width], [z_pos[0], z_pos[0] + z_width],
+ [z_pos[1], z_pos[1] + z_width], [-z_width, None]]:
+ for hh in [[0, h_width], [h_pos[0], h_pos[0] + h_width],
+ [h_pos[1], h_pos[1] + h_width], [-h_width, None]]:
+ for ww in [[0, w_width], [w_pos[0], w_pos[0] + w_width],
+ [w_pos[1], w_pos[1] + w_width], [-w_width, None]]:
+ splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[
+ 1]])
+
+ return torch.cat(splits, 0)
+
+
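+# Recombine the 64 block outputs, cropping interior blocks around their centers along all three axes.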
+def combine64(output, z, h, w):
+ splits = []
+ for i in range(len(output)):
+ splits.append(output[i])
+
+ output = np.zeros(
+ (z, h, w, splits[0].shape[3], splits[0].shape[4]), np.float32)
+
+ z_width = int(np.ceil(float(z) / 4))
+ h_width = int(np.ceil(float(h) / 4))
+ w_width = int(np.ceil(float(w) / 4))
+ splitzstart = splits[0].shape[0] // 2 - z_width // 2
+ splithstart = splits[0].shape[1] // 2 - h_width // 2
+ splitwstart = splits[0].shape[2] // 2 - w_width // 2
+
+ i = 0
+ for zz, zz2 in zip(
+ [[0, z_width], [z_width, z_width * 2], [z_width * 2, z_width * 3],
+ [z_width * 3 - z, None]],
+ [[0, z_width], [splitzstart, z_width + splitzstart],
+ [splitzstart, z_width + splitzstart], [z_width * 3 - z, None]]):
+
+ for hh, hh2 in zip(
+ [[0, h_width], [h_width, h_width * 2], [h_width * 2, h_width * 3],
+ [h_width * 3 - h, None]],
+ [[0, h_width], [splithstart, h_width + splithstart],
+ [splithstart, h_width + splithstart], [h_width * 3 - h, None]]):
+
+ for ww, ww2 in zip(
+ [[0, w_width], [w_width, w_width * 2],
+ [w_width * 2, w_width * 3], [w_width * 3 - w, None]],
+ [[0, w_width], [splitwstart, w_width + splitwstart],
+ [splitwstart, w_width + splitwstart],
+ [w_width * 3 - w, None]]):
+
+ output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[
+ i][zz2[0]:zz2[1], hh2[0]:hh2[1], ww2[0]:ww2[1], :, :]
+ i = i + 1
+
+ return output
diff --git a/code/prediction/train_prediction.py b/code/prediction/train_prediction.py
new file mode 100644
index 0000000..9999d9a
--- /dev/null
+++ b/code/prediction/train_prediction.py
@@ -0,0 +1,245 @@
+# coding=utf8
+
+
+'''
+train_prediction.py: entry point for training the sepsis onset prediction model.
+'''
+
+
+# standard dependencies
+import os
+import sys
+import time
+import json
+import traceback
+import numpy as np
+from glob import glob
+from tqdm import tqdm
+from tools import parse, py_op
+
+
+# torch
+import torch
+import torchvision
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+
+
+# project modules
+import loss
+import models
+import function
+import loaddata
+# import framework
+from loaddata import dataloader
+from models import lstm
+
+
+# global settings
+args = parse.args
+args.hard_mining = 0
+args.gpu = 1
+args.use_trend = max(args.use_trend, args.use_value)
+args.use_value = max(args.use_trend, args.use_value)
+args.rnn_size = args.embed_size
+args.hidden_size = args.embed_size
+
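+# Run one pass over the train or validation loader: forward pass, loss, and metric accumulation; in the train phase also backpropagate and update the model, and in the val phase checkpoint the model when the metric improves.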
+def train_eval(p_dict, phase='train'):
+ ### unpack parameters
+ epoch = p_dict['epoch']
+ model = p_dict['model'] # model
+ loss = p_dict['loss'] # loss function
+ if phase == 'train':
+ data_loader = p_dict['train_loader'] # training data
+ optimizer = p_dict['optimizer'] # optimizer
+ else:
+ data_loader = p_dict['val_loader']
+
+ ### local variables
+ classification_metric_dict = dict()
+ # if args.task == 'case1':
+
+ for i,data in enumerate(tqdm(data_loader)):
+ if args.use_visit:
+ if args.gpu:
+ data = [ Variable(x.cuda()) for x in data ]
+ visits, values, mask, master, labels, times, trends = data
+ if i == 0:
+ print('input size', visits.size())
+ output = model(visits, master, mask, times, phase, values, trends)
+ else:
+ inputs = Variable(data[0].cuda())
+ labels = Variable(data[1].cuda())
+ output = model(inputs)
+
+ # if 0:
+ if args.task == 'task2':
+ output, mask, time = output
+ labels = labels.unsqueeze(-1).expand(output.size()).contiguous()
+ labels[mask==0] = -1
+ else:
+ time = None
+
+ classification_loss_output = loss(output, labels, args.hard_mining)
+ loss_gradient = classification_loss_output[0]
+ # compute performance metrics
+ function.compute_metric(output, labels, time, classification_loss_output, classification_metric_dict, phase)
+
+ # print(outputs.size(), labels.size(),data[3].size(),segment_line_output.size())
+ # print('detection', detect_character_labels.size(), detect_character_output.size())
+ # return
+
+ # training phase: backpropagate and update parameters
+ if phase == 'train':
+ optimizer.zero_grad()
+ loss_gradient.backward()
+ optimizer.step()
+
+ # if i >= 10:
+ # break
+
+
+ print('\nEpoch: {:d} \t Phase: {:s} \n'.format(epoch, phase))
+ metric = function.print_metric('classification', classification_metric_dict, phase)
+ if args.phase != 'train':
+ print('metric = ', metric)
+ print()
+ print()
+ return
+ if phase == 'val':
+ if metric > p_dict['best_metric'][0]:
+ p_dict['best_metric'] = [metric, epoch]
+ function.save_model(p_dict)
+ if 0:
+ # if args.task == 'task2':
+ preds = classification_metric_dict['preds']
+ labels = classification_metric_dict['labels']
+ times = classification_metric_dict['times']
+ fl = open('../result/tauc_label.csv', 'w')
+ fr = open('../result/tauc_result.csv', 'w')
+ fl.write('adm_id,last_event_time,mortality\n')
+ fr.write('adm_id,probability\n')
+ for i, (p,l,t) in enumerate(zip(preds, labels, times)):
+ if i % 30:
+ continue
+ fl.write(str(i) + ',')
+ fl.write(str(t) + ',')
+ fl.write(str(int(l)) + '\n')
+
+ fr.write(str(i) + ',')
+ fr.write(str(p) + '\n')
+
+
+ print('valid: metric: {:3.4f}\t epoch: {:d}\n'.format(metric, epoch))
+ print('\t\t\t valid: best_metric: {:3.4f}\t epoch: {:d}\n'.format(p_dict['best_metric'][0], p_dict['best_metric'][1]))
+ else:
+ print('train: metric: {:3.4f}\t epoch: {:d}\n'.format(metric, epoch))
+
+
+
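+# Load the preprocessed MIMIC records, build train/val DataLoaders, the LSTM model, loss, and Adam optimizer, optionally resume from a checkpoint, then alternate training and validation epochs.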
+def main():
+ p_dict = dict() # All the parameters
+ p_dict['args'] = args
+ args.split_nn = args.split_num + args.split_nor * 3
+ args.vocab_size = args.split_nn * 145 + 1
+ print('vocab_size', args.vocab_size)
+
+ ### load data
+ print('read data ...')
+ patient_time_record_dict = py_op.myreadjson(os.path.join(args.result_dir, 'patient_time_record_dict.json'))
+ patient_master_dict = py_op.myreadjson(os.path.join(args.result_dir, 'patient_master_dict.json'))
+ patient_label_dict = py_op.myreadjson(os.path.join(args.result_dir, 'patient_label_dict.json'))
+
+ patient_train = list(json.load(open(os.path.join(args.file_dir, args.task, 'train.json'))))
+ patient_valid = list(json.load(open(os.path.join(args.file_dir, args.task, 'val.json'))))
+
+ if len(patient_train) > len(patient_label_dict):
+ patients = sorted(patient_label_dict.keys())
+ n = int(0.8 * len(patients))
+ patient_train = patients[:n]
+ patient_valid = patients[n:]
+
+
+
+
+
+ print('data loading ...')
+ train_dataset = dataloader.DataSet(
+ patient_train,
+ patient_time_record_dict,
+ patient_label_dict,
+ patient_master_dict,
+ args=args,
+ phase='train')
+ train_loader = DataLoader(
+ dataset=train_dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ num_workers=8,
+ pin_memory=True)
+ val_dataset = dataloader.DataSet(
+ patient_valid,
+ patient_time_record_dict,
+ patient_label_dict,
+ patient_master_dict,
+ args=args,
+ phase='val')
+ val_loader = DataLoader(
+ dataset=val_dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=8,
+ pin_memory=True)
+
+ p_dict['train_loader'] = train_loader
+ p_dict['val_loader'] = val_loader
+
+
+
+ cudnn.benchmark = True
+ net = lstm.LSTM(args)
+ if args.gpu:
+ net = net.cuda()
+ p_dict['loss'] = loss.Loss().cuda()
+ else:
+ p_dict['loss'] = loss.Loss()
+
+ parameters = []
+ for p in net.parameters():
+ parameters.append(p)
+ optimizer = torch.optim.Adam(parameters, lr=args.lr)
+ p_dict['optimizer'] = optimizer
+ p_dict['model'] = net
+ start_epoch = 0
+ # args.epoch = start_epoch
+ # print ('best_f1score' + str(best_f1score))
+
+ p_dict['epoch'] = 0
+ p_dict['best_metric'] = [0, 0]
+
+
+ ### resume pretrained model
+ if os.path.exists(args.resume):
+ print('resume from model ' + args.resume)
+ function.load_model(p_dict, args.resume)
+ print('best_metric', p_dict['best_metric'])
+ # return
+
+
+ if args.phase == 'train':
+
+ best_f1score = 0
+ for epoch in range(p_dict['epoch'] + 1, args.epochs):
+ p_dict['epoch'] = epoch
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = args.lr
+ train_eval(p_dict, 'train')
+ train_eval(p_dict, 'val')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/code/preprocessing/generate_sepsis_variables.py b/code/preprocessing/generate_sepsis_variables.py
new file mode 100644
index 0000000..0c47361
--- /dev/null
+++ b/code/preprocessing/generate_sepsis_variables.py
@@ -0,0 +1,366 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+
+import sys
+
+import os
+import sys
+import time
+import numpy as np
+from sklearn import metrics
+import random
+import json
+from glob import glob
+from collections import OrderedDict
+from tqdm import tqdm
+
+sys.path.append('../tools')
+import parse, py_op
+
+args = parse.args
+
+
+variable_map_dict = {
+ # lab
+ 'WBC': 'wbc',
+ 'bun': 'bun',
+ 'sodium': 'sodium',
+ 'pt': 'pt',
+ 'INR': 'inr',
+ 'PTT': 'ptt',
+ 'platelet': 'platelet',
+ 'lactate' : 'lactate',
+ 'hemoglobin': 'hemoglobin',
+ 'glucose': 'glucose',
+ 'chloride': 'chloride',
+ 'creatinine': 'creatinine',
+ 'aniongap': 'aniongap',
+ 'bicarbonate': 'bicarbonate',
+
+ # other lab
+ 'hematocrit': 'hematocrit',
+
+ # used
+ 'heart rate': 'heartrate',
+ 'respiratory rate': 'resprate',
+ 'temperature': 'tempc',
+ 'meanbp': 'meanbp',
+ 'gcs': 'gcs_min',
+ 'urineoutput': 'urineoutput',
+ 'sysbp': 'sysbp',
+ 'diasbp': 'diasbp',
+ 'spo2': 'spo2',
+ 'Magnesium': '',
+
+ 'C-reactive protein': '',
+ 'bands': 'bands',
+ }
+
+item_id_dict = {
+ 'C-reactive protein': '50889',
+ 'Magnesium': '50960',
+ }
+
+def time_to_second(t):
+ t = str(t).replace('"', '')
+ t = time.mktime(time.strptime(t,'%Y-%m-%d %H:%M:%S'))
+ return int(t)
+
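+# Select rows of LABEVENTS.csv for the extra variables (e.g. C-reactive protein, Magnesium) that are not covered by the pivoted_* views and write them to sepsis_lab.csv.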
+def select_records_of_variables_not_in_pivoted():
+ count_dict = { v:0 for v in item_id_dict.values() }
+ hadm_time_dict = py_op.myreadjson(os.path.join(args.data_dir, args.dataset, 'hadm_time_dict.json' ))
+ wf = open(os.path.join(args.mimic_dir, 'sepsis_lab.csv'), 'w')
+ for i_line, line in enumerate(open(os.path.join(args.mimic_dir, 'LABEVENTS.csv'))):
+ if i_line:
+ line_data = line.split(',')
+ if len(line_data) == 0:
+ continue
+ hadm_id, item_id, ctime = line_data[2:5]
+ value = line_data[5]
+ if item_id in count_dict and hadm_id in hadm_time_dict:
+ # print(line)
+ if len(line_data) != 9:
+ print(line)
+ # assert len(line_data) == 9
+ count_dict[item_id] += 1
+ wf.write(line)
+ else:
+ continue
+ if i_line % 10000 == 0:
+ print(i_line)
+ wf.close()
+
+
+
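+# Convert the selected lab rows into pivoted_add.csv, with one column per extra variable.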
+def generate_variables_not_in_pivoted():
+ assert args.dataset == 'MIMIC'
+ id_item_dict = { v:k for k,v in item_id_dict.items() }
+ head = sorted(item_id_dict)
+ count_dict = { v:0 for v in item_id_dict.values() }
+ wf = open(os.path.join(args.mimic_dir, 'pivoted_add.csv'), 'w')
+ wf.write(','.join(['hadm_id', 'charttime'] + head) + '\n')
+ for i_line, line in enumerate(open(os.path.join(args.mimic_dir, 'sepsis_lab.csv'))):
+ if i_line:
+ line_data = py_op.csv_split(line)
+ hadm_id, item_id, ctime = line_data[2:5]
+ value = line_data[6]
+ try:
+ value = float(value)
+ index = head.index(id_item_dict[item_id])
+ new_line = [hadm_id, ctime] + ['' for _ in range(index)] + [str(value)] + ['' for _ in range(index, len(head)-1)]
+ new_line = ','.join(new_line) + '\n'
+ wf.write(new_line)
+ except:
+ continue
+ count_dict[item_id] += 1
+ last_time = ctime
+ else:
+ print(line)
+ print(count_dict)
+
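+# Merge the pivoted_*.csv sources into one csv per admission (hadm_id), converting chart times to hours since admission; SOFA rows are additionally written to per-admission files.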
+def merge_pivoted_data(csv_list):
+ name_list = ['hadm_id', 'charttime']
+ for k,v in variable_map_dict.items():
+ if k not in ['age', 'gender']:
+ if len(v):
+ name_list.append(v)
+ elif k in item_id_dict:
+ name_list.append(k)
+ name_index_dict = { name:id for id,name in enumerate(name_list) }
+
+ hadm_time_dict = py_op.myreadjson(os.path.join(args.data_dir, args.dataset, 'hadm_time_dict.json' ))
+ icu_hadm_dict = py_op.myreadjson(os.path.join(args.data_dir, args.dataset, 'icu_hadm_dict.json' ))
+ merge_dir = os.path.join(args.data_dir, args.dataset, 'merge_pivoted')
+ os.system('rm -r ' + merge_dir)
+ os.system('mkdir ' + merge_dir)
+ pivoted_dir = os.path.join(args.result_dir, 'mimic/pivoted_sofa')
+ py_op.mkdir(pivoted_dir)
+
+ for fi in csv_list:
+ print(fi)
+ for i_line, line in enumerate(open(os.path.join(args.mimic_dir, fi))):
+ if i_line:
+ line_data = line.strip().split(',')
+ if len(line_data) <= 0:
+ continue
+ line_dict = dict()
+ for iv, v in enumerate(line_data):
+ if len(v.strip()):
+ name = head[iv]
+ line_dict[name] = v
+
+ if fi == 'pivoted_sofa.csv':
+ icu_id = line_dict.get('icustay_id', 'xxx')
+ if icu_id not in icu_hadm_dict:
+ continue
+ hadm_id = str(icu_hadm_dict[icu_id])
+ line_dict['hadm_id'] = hadm_id
+ line_dict['charttime'] = line_dict['starttime']
+
+
+ hadm_id = line_dict.get('hadm_id', 'xxx')
+ if hadm_id not in hadm_time_dict:
+ continue
+ hadm_time = time_to_second(hadm_time_dict[hadm_id])
+ now_time = time_to_second(line_dict['charttime'])
+ delta_hour = int((now_time - hadm_time) / 3600)
+ line_dict['charttime'] = str(delta_hour)
+
+ if fi == 'pivoted_sofa.csv':
+ sofa_file = os.path.join(pivoted_dir, hadm_id + '.csv')
+ if not os.path.exists(sofa_file):
+ with open(sofa_file, 'w') as f:
+ f.write(sofa_head)
+ wf = open(sofa_file, 'a')
+ sofa_line = [str(delta_hour)] + line.split(',')[4:]
+ wf.write(','.join(sofa_line))
+ wf.close()
+
+
+ assert 'hadm_id' in line_dict
+ assert 'charttime' in line_dict
+ new_line = []
+ for name in name_list:
+ new_line.append(line_dict.get(name, ''))
+ new_line = ','.join(new_line) + '\n'
+ hadm_file = os.path.join(merge_dir, hadm_id + '.csv')
+ if not os.path.exists(hadm_file):
+ with open(hadm_file, 'w') as f:
+ f.write(','.join(name_list) + '\n')
+ wf = open(hadm_file, 'a')
+ wf.write(new_line)
+ wf.close()
+
+ else:
+ if fi == 'pivoted_sofa.csv':
+ sofa_head = ','.join(['time'] + line.replace('"', '').split(',')[4:])
+ # "icustay_id","hr","starttime","endtime","pao2fio2ratio_novent","pao2fio2ratio_vent","rate_epinephrine","rate_norepinephrine","rate_dopamine","rate_dobutamine","meanbp_min","gcs_min","urineoutput","bilirubin_max","creatinine_max","platelet_min","respiration","coagulation","liver","cardiovascular","cns","renal","respiration_24hours","coagulation_24hours","liver_24hours","cardiovascular_24hours","cns_24hours","renal_24hours","sofa_24hours"
+
+
+ head = line.replace('"', '').strip().split(',')
+ head = [h.strip() for h in head]
+ # print(line)
+ for h in head:
+ if h not in name_index_dict:
+ print(h)
+
+def sort_pivoted_data():
+ sort_dir = os.path.join(args.data_dir, args.dataset, 'sort_pivoted')
+ os.system('rm -r ' + sort_dir)
+ os.system('mkdir ' + sort_dir)
+ merge_dir = os.path.join(args.data_dir, args.dataset, 'merge_pivoted')
+
+ for i_fi, fi in enumerate(tqdm(os.listdir(merge_dir))):
+ wf = open(os.path.join(sort_dir, fi), 'w')
+ time_line_dict = dict()
+ for i_line, line in enumerate(open(os.path.join(merge_dir, fi))):
+ if i_line:
+ line_data = line.strip().split(',')
+ delta = 3
+ ctime = delta * int(int(line_data[1]) / delta)
+ if ctime not in time_line_dict:
+ time_line_dict[ctime] = []
+ time_line_dict[ctime].append(line_data)
+ else:
+ line_data = line.split(',')[1:]
+ line_data[0] = 'time'
+ wf.write(','.join(line_data))
+ for t in sorted(time_line_dict):
+ line_list = time_line_dict[t]
+ new_line = line_list[0]
+ for line_data in line_list[1:]:
+ for iv, v in enumerate(line_data):
+ if len(v.strip()):
+ new_line[iv] = v
+ new_line = ','.join(new_line[1:]) + '\n'
+ wf.write(new_line)
+ wf.close()
+ py_op.mkdir('../../data/MIMIC/train_groundtruth')
+ py_op.mkdir('../../data/MIMIC/train_with_missing')
+ os.system('rm ../../data/MIMIC/train_groundtruth/*.csv')
+ os.system('cp ../../data/MIMIC/sort_pivoted/* ../../data/MIMIC/train_groundtruth/')
+
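+# Build icu_mortality_dict.json, mapping each icustay_id to its mortality label from sepsis_mortality.csv.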
+def generate_icu_mortality_dict(icustay_id_list):
+ icu_mortality_dict = dict()
+ for i_line, line in enumerate(open(os.path.join(args.mimic_dir, 'sepsis_mortality.csv'))):
+ if i_line:
+ if i_line % 10000 == 0:
+ print(i_line)
+ line_data = line.strip().split(',')
+ icustay_id = line_data[0]
+ icu_mortality_dict[icustay_id] = int(line_data[-1])
+ py_op.mywritejson(os.path.join(args.data_dir, 'icu_mortality_dict.json'), icu_mortality_dict)
+
+
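+# Drop short or overly sparse admissions, keep the remaining records as ground truth, and mask one observed value per variable to build the train_with_missing/ copies.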
+def generate_lab_missing_values():
+ lab_files = glob(os.path.join(args.data_dir, args.dataset, 'train_groundtruth/*.csv'))
+ os.system('rm -r {:s}/*'.format(os.path.join(args.data_dir, args.dataset, 'train_with_missing')))
+ feat_count_dict = dict()
+ line_count_dict = dict()
+ n_full = 0
+ for i_fi, fi in enumerate(tqdm(lab_files)):
+ file_data = []
+ valid_list = []
+ last_data = [-10000]
+ for i_line, line in enumerate(open(fi)):
+ if i_line:
+ data = line.strip().split(',')
+ # print(data)
+ # print(fi)
+ # assert(int(data[0])) > -200
+ # assert(int(data[0])) < 800
+ if int(data[0]) < -24 or int(data[0]) >= 500:
+ continue
+ assert int(data[0]) > -200
+ valid = []
+ for i in range(len(data)):
+ feat_count_dict[feat_list[i]][0] += 1
+ if data[i] in ['', 'NA']:
+ feat_count_dict[feat_list[i]][1] += 1
+ valid.append(0)
+ else:
+ valid.append(1)
+ vector[i] = 1
+
+ if data[0] == last_data[0]:
+ for iv in range(len(data)):
+ if valid[iv]:
+ last_valid[iv] = valid[iv]
+ last_data[iv] = data[iv]
+ valid_list[-1] = last_valid
+ file_data[-1] = last_data
+ else:
+ valid_list.append(valid)
+ assert int(data[0]) < 700
+ assert int(data[0]) > - 200
+ file_data.append(data)
+ last_data = data
+ last_valid = valid
+ else:
+ feat_list = line.strip().split(',')
+ vector = [0 for _ in feat_list]
+ for feat in feat_list:
+ if feat not in feat_count_dict:
+ feat_count_dict[feat] = [0, 0]
+
+ valid_list.append([1 for _ in feat_list])
+ file_data.append(feat_list)
+ line_count_dict[i_line] = line_count_dict.get(i_line, 0) + 1
+
+ vs = [0 for _ in file_data[0]]
+ for data in file_data[1:]:
+ for iv, v in enumerate(data):
+ if v.strip() not in ['', 'NA']:
+ vs[iv] += 1
+ if np.min(vs) >= 2:
+ n_full +=1
+ # continue
+
+
+ # if len(file_data)< 15 or np.min(vs) < 2:
+ if len(file_data)< 5 or sorted(vector)[2] < 1:
+ os.system('rm -r ' + fi)
+ # os.system('rm -r ' + fi.replace('groundtruth', 'with_missing'))
+ # print('rm -r ' + fi.replace('groundtruth', 'with_missing'))
+ else:
+ for data in file_data[1:]:
+ assert int(data[0]) > -200
+ # write groundtruth data
+ x = [','.join(line) for line in file_data]
+ x = '\n'.join(x)
+ with open(fi, 'w') as f:
+ f.write(x)
+
+ valid_list = np.array(valid_list)
+ valid_list[0] = 0
+ for i in range(1, valid_list.shape[1]):
+ valid = valid_list[:, i]
+ indices = np.where(valid > 0)[0]
+ indices = sorted(indices)
+ if len(indices) > 2:
+ indices = indices[1:-1]
+ np.random.shuffle(indices)
+ file_data[indices[0]][i] = ''
+ # write the copy with simulated missing values
+ x = [','.join(line) for line in file_data]
+ x = '\n'.join(x)
+ with open(fi.replace('groundtruth', 'with_missing'), 'w') as f:
+ f.write(x)
+ print(n_full)
+
+
+def main():
+ csv_list = ['pivoted_sofa.csv', 'pivoted_add.csv', 'pivoted_lab.csv', 'pivoted_vital.csv']
+ select_records_of_variables_not_in_pivoted()
+ generate_variables_not_in_pivoted()
+ merge_pivoted_data(csv_list)
+ sort_pivoted_data()
+ generate_lab_missing_values()
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/code/preprocessing/generate_value_distribution.py b/code/preprocessing/generate_value_distribution.py
new file mode 100644
index 0000000..872d0e4
--- /dev/null
+++ b/code/preprocessing/generate_value_distribution.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+
+import sys
+
+import os
+import sys
+import time
+import numpy as np
+from sklearn import metrics
+import random
+import json
+from glob import glob
+from collections import OrderedDict
+from tqdm import tqdm
+
+
+sys.path.append('../tools')
+import parse, py_op
+args = parse.args
+
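+# Scan the ground-truth csvs to collect per-variable values, then save trimmed min/max bounds, mean/std, and split_num quantile boundaries as json files.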
+def generate_feature_mm_dict():
+ files = sorted(glob(os.path.join(args.data_dir, args.dataset, 'train_groundtruth/*')))
+ feature_value_dict = dict()
+ for ifi, fi in enumerate(tqdm(files)):
+ if 'csv' not in fi:
+ continue
+ for iline, line in enumerate(open(fi)):
+ line = line.strip()
+ if iline == 0:
+ feat_list = line.split(',')
+ else:
+ data = line.split(',')
+ for iv, v in enumerate(data):
+ if v in ['NA', '']:
+ continue
+ else:
+ feat = feat_list[iv]
+ if feat not in feature_value_dict:
+ feature_value_dict[feat] = []
+ feature_value_dict[feat].append(float(v))
+ feature_mm_dict = dict()
+ feature_ms_dict = dict()
+
+ feature_range_dict = dict()
+ for feat, vs in feature_value_dict.items():
+ vs = sorted(vs)
+ value_split = []
+ for i in range(args.split_num):
+ n = int(i * len(vs) / args.split_num)
+ value_split.append(vs[n])
+ value_split.append(vs[-1])
+ feature_range_dict[feat] = value_split
+
+
+ n = int(len(vs) / args.split_num)
+ feature_mm_dict[feat] = [vs[n], vs[-n - 1]]
+ feature_ms_dict[feat] = [np.mean(vs), np.std(vs)]
+
+ py_op.mkdir(args.file_dir)
+ py_op.mywritejson(os.path.join(args.file_dir, args.dataset + '_feature_mm_dict.json'), feature_mm_dict)
+ py_op.mywritejson(os.path.join(args.file_dir, args.dataset + '_feature_ms_dict.json'), feature_ms_dict)
+ py_op.mywritejson(os.path.join(args.file_dir, args.dataset + '_feature_list.json'), feat_list)
+ py_op.mywritejson(os.path.join(args.file_dir, args.dataset + '_feature_value_dict_{:d}.json'.format(args.split_num)), feature_range_dict)
+
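+# Randomly split the admissions in train_with_missing/ into 10 folds and save the split.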
+def split_data_to_ten_set():
+ files = sorted(glob(os.path.join(args.data_dir, args.dataset, 'train_with_missing/*')))
+ np.random.shuffle(files)
+ splits = []
+ for i in range(10):
+ st = int(len(files) * i / 10)
+ en = int(len(files) * (i+1) / 10)
+ splits.append(files[st:en])
+ py_op.mywritejson(os.path.join(args.file_dir, args.dataset + '_splits.json'), splits)
+
+
+def main():
+ generate_feature_mm_dict()
+ split_data_to_ten_set()
+
+if __name__ == '__main__':
+ main()
diff --git a/code/preprocessing/pivoted_file_generation.md b/code/preprocessing/pivoted_file_generation.md
new file mode 100644
index 0000000..6f57120
--- /dev/null
+++ b/code/preprocessing/pivoted_file_generation.md
@@ -0,0 +1,334 @@
+# SQL for pivoted\_\*.csv generation
+
+```sql
+-- Drop table
+
+-- DROP TABLE public.pivoted_sofa;
+
+CREATE TABLE public.pivoted_sofa (
+ icustay_id int4 NULL,
+ hr int4 NULL,
+ starttime timestamp NULL,
+ endtime timestamp NULL,
+ pao2fio2ratio_novent float8 NULL,
+ pao2fio2ratio_vent float8 NULL,
+ rate_epinephrine float8 NULL,
+ rate_norepinephrine float8 NULL,
+ rate_dopamine float8 NULL,
+ rate_dobutamine float8 NULL,
+ meanbp_min float8 NULL,
+ gcs_min float8 NULL,
+ urineoutput float8 NULL,
+ bilirubin_max float8 NULL,
+ creatinine_max float8 NULL,
+ platelet_min float8 NULL,
+ respiration int2 NULL,
+ coagulation int2 NULL,
+ liver int2 NULL,
+ cardiovascular int2 NULL,
+ cns int2 NULL,
+ renal int2 NULL,
+ respiration_24hours int2 NULL,
+ coagulation_24hours int2 NULL,
+ liver_24hours int2 NULL,
+ cardiovascular_24hours int2 NULL,
+ cns_24hours int2 NULL,
+ renal_24hours int2 NULL,
+ sofa_24hours int2 NULL
+);
+
+
+
+
+
+
+CREATE MATERIALIZED VIEW public.pivoted_lab
+TABLESPACE pg_default
+AS WITH i AS (
+ SELECT icustays.subject_id,
+ icustays.icustay_id,
+ icustays.intime,
+ icustays.outtime,
+ lag(icustays.outtime) OVER (PARTITION BY icustays.subject_id ORDER BY icustays.intime) AS outtime_lag,
+ lead(icustays.intime) OVER (PARTITION BY icustays.subject_id ORDER BY icustays.intime) AS intime_lead
+ FROM icustays
+ ), iid_assign AS (
+ SELECT i.subject_id,
+ i.icustay_id,
+ CASE
+ WHEN i.outtime_lag IS NOT NULL AND i.outtime_lag > (i.intime - '24:00:00'::interval hour) THEN i.intime - (i.intime - i.outtime_lag) / 2::double precision
+ ELSE i.intime - '12:00:00'::interval hour
+ END AS data_start,
+ CASE
+ WHEN i.intime_lead IS NOT NULL AND i.intime_lead < (i.outtime + '24:00:00'::interval hour) THEN i.outtime + (i.intime_lead - i.outtime) / 2::double precision
+ ELSE i.outtime + '12:00:00'::interval hour
+ END AS data_end
+ FROM i
+ ), h AS (
+ SELECT admissions.subject_id,
+ admissions.hadm_id,
+ admissions.admittime,
+ admissions.dischtime,
+ lag(admissions.dischtime) OVER (PARTITION BY admissions.subject_id ORDER BY admissions.admittime) AS dischtime_lag,
+ lead(admissions.admittime) OVER (PARTITION BY admissions.subject_id ORDER BY admissions.admittime) AS admittime_lead
+ FROM admissions
+ ), adm AS (
+ SELECT h.subject_id,
+ h.hadm_id,
+ CASE
+ WHEN h.dischtime_lag IS NOT NULL AND h.dischtime_lag > (h.admittime - '24:00:00'::interval hour) THEN h.admittime - (h.admittime - h.dischtime_lag) / 2::double precision
+ ELSE h.admittime - '12:00:00'::interval hour
+ END AS data_start,
+ CASE
+ WHEN h.admittime_lead IS NOT NULL AND h.admittime_lead < (h.dischtime + '24:00:00'::interval hour) THEN h.dischtime + (h.admittime_lead - h.dischtime) / 2::double precision
+ ELSE h.dischtime + '12:00:00'::interval hour
+ END AS data_end
+ FROM h
+ ), le AS (
+ SELECT labevents.subject_id,
+ labevents.charttime,
+ CASE
+ WHEN labevents.itemid = 50868 THEN 'ANION GAP'::text
+ WHEN labevents.itemid = 50862 THEN 'ALBUMIN'::text
+ WHEN labevents.itemid = 51144 THEN 'BANDS'::text
+ WHEN labevents.itemid = 50882 THEN 'BICARBONATE'::text
+ WHEN labevents.itemid = 50885 THEN 'BILIRUBIN'::text
+ WHEN labevents.itemid = 50912 THEN 'CREATININE'::text
+ WHEN labevents.itemid = 50902 THEN 'CHLORIDE'::text
+ WHEN labevents.itemid = 50931 THEN 'GLUCOSE'::text
+ WHEN labevents.itemid = 51221 THEN 'HEMATOCRIT'::text
+ WHEN labevents.itemid = 51222 THEN 'HEMOGLOBIN'::text
+ WHEN labevents.itemid = 50813 THEN 'LACTATE'::text
+ WHEN labevents.itemid = 51265 THEN 'PLATELET'::text
+ WHEN labevents.itemid = 50971 THEN 'POTASSIUM'::text
+ WHEN labevents.itemid = 51275 THEN 'PTT'::text
+ WHEN labevents.itemid = 51237 THEN 'INR'::text
+ WHEN labevents.itemid = 51274 THEN 'PT'::text
+ WHEN labevents.itemid = 50983 THEN 'SODIUM'::text
+ WHEN labevents.itemid = 51006 THEN 'BUN'::text
+ WHEN labevents.itemid = 51300 THEN 'WBC'::text
+ WHEN labevents.itemid = 51301 THEN 'WBC'::text
+ ELSE NULL::text
+ END AS label,
+ CASE
+ WHEN labevents.itemid = 50862 AND labevents.valuenum > 10::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50868 AND labevents.valuenum > 10000::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 51144 AND labevents.valuenum < 0::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 51144 AND labevents.valuenum > 100::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50882 AND labevents.valuenum > 10000::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50885 AND labevents.valuenum > 150::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50806 AND labevents.valuenum > 10000::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50902 AND labevents.valuenum > 10000::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50912 AND labevents.valuenum > 150::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50809 AND labevents.valuenum > 10000::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50931 AND labevents.valuenum > 10000::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50810 AND labevents.valuenum > 100::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 51221 AND labevents.valuenum > 100::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50811 AND labevents.valuenum > 50::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 51222 AND labevents.valuenum > 50::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50813 AND labevents.valuenum > 50::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 51265 AND labevents.valuenum > 10000::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50822 AND labevents.valuenum > 30::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50971 AND labevents.valuenum > 30::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 51275 AND labevents.valuenum > 150::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 51237 AND labevents.valuenum > 50::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 51274 AND labevents.valuenum > 150::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50824 AND labevents.valuenum > 200::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 50983 AND labevents.valuenum > 200::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 51006 AND labevents.valuenum > 300::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 51300 AND labevents.valuenum > 1000::double precision THEN NULL::double precision
+ WHEN labevents.itemid = 51301 AND labevents.valuenum > 1000::double precision THEN NULL::double precision
+ ELSE labevents.valuenum
+ END AS valuenum
+ FROM labevents
+ WHERE (labevents.itemid = ANY (ARRAY[50868, 50862, 51144, 50882, 50885, 50912, 50902, 50931, 51221, 51222, 50813, 51265, 50971, 51275, 51237, 51274, 50983, 51006, 51301, 51300])) AND labevents.valuenum IS NOT NULL AND labevents.valuenum > 0::double precision
+ ), le_avg AS (
+ SELECT le.subject_id,
+ le.charttime,
+ avg(
+ CASE
+ WHEN le.label = 'ANION GAP'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS aniongap,
+ avg(
+ CASE
+ WHEN le.label = 'ALBUMIN'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS albumin,
+ avg(
+ CASE
+ WHEN le.label = 'BANDS'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS bands,
+ avg(
+ CASE
+ WHEN le.label = 'BICARBONATE'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS bicarbonate,
+ avg(
+ CASE
+ WHEN le.label = 'BILIRUBIN'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS bilirubin,
+ avg(
+ CASE
+ WHEN le.label = 'CREATININE'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS creatinine,
+ avg(
+ CASE
+ WHEN le.label = 'CHLORIDE'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS chloride,
+ avg(
+ CASE
+ WHEN le.label = 'GLUCOSE'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS glucose,
+ avg(
+ CASE
+ WHEN le.label = 'HEMATOCRIT'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS hematocrit,
+ avg(
+ CASE
+ WHEN le.label = 'HEMOGLOBIN'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS hemoglobin,
+ avg(
+ CASE
+ WHEN le.label = 'LACTATE'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS lactate,
+ avg(
+ CASE
+ WHEN le.label = 'PLATELET'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS platelet,
+ avg(
+ CASE
+ WHEN le.label = 'POTASSIUM'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS potassium,
+ avg(
+ CASE
+ WHEN le.label = 'PTT'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS ptt,
+ avg(
+ CASE
+ WHEN le.label = 'INR'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS inr,
+ avg(
+ CASE
+ WHEN le.label = 'PT'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS pt,
+ avg(
+ CASE
+ WHEN le.label = 'SODIUM'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS sodium,
+ avg(
+ CASE
+ WHEN le.label = 'BUN'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS bun,
+ avg(
+ CASE
+ WHEN le.label = 'WBC'::text THEN le.valuenum
+ ELSE NULL::double precision
+ END) AS wbc
+ FROM le
+ GROUP BY le.subject_id, le.charttime
+ )
+ SELECT iid.icustay_id,
+ adm.hadm_id,
+ le_avg.subject_id,
+ le_avg.charttime,
+ le_avg.aniongap,
+ le_avg.albumin,
+ le_avg.bands,
+ le_avg.bicarbonate,
+ le_avg.bilirubin,
+ le_avg.creatinine,
+ le_avg.chloride,
+ le_avg.glucose,
+ le_avg.hematocrit,
+ le_avg.hemoglobin,
+ le_avg.lactate,
+ le_avg.platelet,
+ le_avg.potassium,
+ le_avg.ptt,
+ le_avg.inr,
+ le_avg.pt,
+ le_avg.sodium,
+ le_avg.bun,
+ le_avg.wbc
+ FROM le_avg
+ LEFT JOIN adm ON le_avg.subject_id = adm.subject_id AND le_avg.charttime >= adm.data_start AND le_avg.charttime < adm.data_end
+ LEFT JOIN iid_assign iid ON le_avg.subject_id = iid.subject_id AND le_avg.charttime >= iid.data_start AND le_avg.charttime < iid.data_end
+ ORDER BY le_avg.subject_id, le_avg.charttime
+WITH DATA;
+
+
+CREATE MATERIALIZED VIEW public.pivoted_vital
+TABLESPACE pg_default
+AS WITH ce AS (
+ SELECT ce_1.icustay_id,
+ ce_1.charttime,
+ CASE
+ WHEN (ce_1.itemid = ANY (ARRAY[211, 220045])) AND ce_1.valuenum > 0::double precision AND ce_1.valuenum < 300::double precision THEN ce_1.valuenum
+ ELSE NULL::double precision
+ END AS heartrate,
+ CASE
+ WHEN (ce_1.itemid = ANY (ARRAY[51, 442, 455, 6701, 220179, 220050])) AND ce_1.valuenum > 0::double precision AND ce_1.valuenum < 400::double precision THEN ce_1.valuenum
+ ELSE NULL::double precision
+ END AS sysbp,
+ CASE
+ WHEN (ce_1.itemid = ANY (ARRAY[8368, 8440, 8441, 8555, 220180, 220051])) AND ce_1.valuenum > 0::double precision AND ce_1.valuenum < 300::double precision THEN ce_1.valuenum
+ ELSE NULL::double precision
+ END AS diasbp,
+ CASE
+ WHEN (ce_1.itemid = ANY (ARRAY[456, 52, 6702, 443, 220052, 220181, 225312])) AND ce_1.valuenum > 0::double precision AND ce_1.valuenum < 300::double precision THEN ce_1.valuenum
+ ELSE NULL::double precision
+ END AS meanbp,
+ CASE
+ WHEN (ce_1.itemid = ANY (ARRAY[615, 618, 220210, 224690])) AND ce_1.valuenum > 0::double precision AND ce_1.valuenum < 70::double precision THEN ce_1.valuenum
+ ELSE NULL::double precision
+ END AS resprate,
+ CASE
+ WHEN (ce_1.itemid = ANY (ARRAY[223761, 678])) AND ce_1.valuenum > 70::double precision AND ce_1.valuenum < 120::double precision THEN (ce_1.valuenum - 32::double precision) / 1.8::double precision
+ WHEN (ce_1.itemid = ANY (ARRAY[223762, 676])) AND ce_1.valuenum > 10::double precision AND ce_1.valuenum < 50::double precision THEN ce_1.valuenum
+ ELSE NULL::double precision
+ END AS tempc,
+ CASE
+ WHEN (ce_1.itemid = ANY (ARRAY[646, 220277])) AND ce_1.valuenum > 0::double precision AND ce_1.valuenum <= 100::double precision THEN ce_1.valuenum
+ ELSE NULL::double precision
+ END AS spo2,
+ CASE
+ WHEN (ce_1.itemid = ANY (ARRAY[807, 811, 1529, 3745, 3744, 225664, 220621, 226537])) AND ce_1.valuenum > 0::double precision THEN ce_1.valuenum
+ ELSE NULL::double precision
+ END AS glucose
+ FROM chartevents ce_1
+ WHERE ce_1.error IS DISTINCT FROM 1 AND (ce_1.itemid = ANY (ARRAY[211, 220045, 51, 442, 455, 6701, 220179, 220050, 8368, 8440, 8441, 8555, 220180, 220051, 456, 52, 6702, 443, 220052, 220181, 225312, 618, 615, 220210, 224690, 646, 220277, 807, 811, 1529, 3745, 3744, 225664, 220621, 226537, 223762, 676, 223761, 678]))
+ )
+ SELECT ce.icustay_id,
+ ce.charttime,
+ avg(ce.heartrate) AS heartrate,
+ avg(ce.sysbp) AS sysbp,
+ avg(ce.diasbp) AS diasbp,
+ avg(ce.meanbp) AS meanbp,
+ avg(ce.resprate) AS resprate,
+ avg(ce.tempc) AS tempc,
+ avg(ce.spo2) AS spo2,
+ avg(ce.glucose) AS glucose
+ FROM ce
+ GROUP BY ce.icustay_id, ce.charttime
+ ORDER BY ce.icustay_id, ce.charttime
+WITH DATA;
+
+
+
+```
diff --git a/code/preprocessing/preprocess_mimic_data.py b/code/preprocessing/preprocess_mimic_data.py
new file mode 100644
index 0000000..dbe46b6
--- /dev/null
+++ b/code/preprocessing/preprocess_mimic_data.py
@@ -0,0 +1,199 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+
+import sys
+
+import os
+import sys
+import time
+import numpy as np
+from sklearn import metrics
+import random
+import json
+from glob import glob
+from collections import OrderedDict
+from tqdm import tqdm
+
+sys.path.append('../tools')
+import parse, py_op
+
+args = parse.args
+args.data_dir = os.path.join(args.data_dir, args.dataset)
+
+def time_to_second(t):
+ t = str(t).replace('"', '')
+ t = time.mktime(time.strptime(t,'%Y-%m-%d %H:%M:%S'))
+ return int(t)
+
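+# Keep EHR codes that occur more than 100 times and save them to ehr_list.json.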
+def map_ehr_id():
+ print('start')
+ ehr_count_dict = py_op.myreadjson(os.path.join(args.data_dir, 'ehr_count_dict.json'))
+ ehr_list = [ehr for ehr,c in ehr_count_dict.items() if c > 100]
+ ns = set('0123456789')
+ print(ns)
+ drug_list = [e for e in ehr_list if e[1] in ns]
+ med_list = [e for e in ehr_list if e[1] not in ns]
+ print(len(drug_list))
+ print(len(med_list))
+ py_op.mywritejson(os.path.join(args.data_dir, 'ehr_list.json'), ehr_list)
+
+
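+# For each admission, write a json with demographics, diagnosis history, and prescribed drugs keyed by time windows in hours since admission; also count code frequencies into ehr_count_dict.json.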
+def generate_ehr_files():
+
+ hadm_time_dict = py_op.myreadjson(os.path.join(args.data_dir, 'hadm_time_dict.json'))
+ hadm_demo_dict = py_op.myreadjson(os.path.join(args.data_dir, 'hadm_demo_dict.json'))
+ hadm_sid_dict = py_op.myreadjson(os.path.join(args.data_dir, 'hadm_sid_dict.json'))
+ hadm_icd_dict = py_op.myreadjson(os.path.join(args.data_dir, 'hadm_icd_dict.json'))
+ hadm_time_drug_dict = py_op.myreadjson(os.path.join(args.data_dir, 'hadm_time_drug_dict.json'))
+ groundtruth_dir = os.path.join(args.data_dir, 'train_groundtruth')
+ py_op.mkdir(groundtruth_dir)
+ ehr_count_dict = dict()
+
+ for hadm_id in hadm_sid_dict:
+
+ time_drug_dict = hadm_time_drug_dict.get(hadm_id, { })
+ icd_list = hadm_icd_dict.get(hadm_id, [])
+ demo = hadm_demo_dict[hadm_id]
+ demo[0] = demo[0] + '1'
+ demo[1] = 'A' + str(int(demo[1] / 9))
+ icd_demo = icd_list + demo
+
+ for icd in icd_demo:
+ ehr_count_dict[icd] = ehr_count_dict.get(icd, 0) + 1
+
+
+
+ ehr_dict = { 'drug':{ }, 'icd_demo': icd_demo}
+
+ for setime, drug_list in time_drug_dict.items():
+ try:
+ stime,etime = setime.split(' -- ')
+ start_second = time_to_second(hadm_time_dict[hadm_id])
+ stime = str((time_to_second(stime) - start_second) / 3600)
+ etime = str((time_to_second(etime) - start_second) / 3600)
+ setime = stime + ' -- ' + etime
+ for drug in drug_list:
+ ehr_count_dict[drug] = ehr_count_dict.get(drug, 0) + 1
+ ehr_dict['drug'][setime] = list(set(drug_list))
+ except:
+ pass
+
+
+ py_op.mywritejson(os.path.join(groundtruth_dir, hadm_id + '.json'), ehr_dict)
+ # break
+ py_op.mywritejson(os.path.join(args.data_dir, 'ehr_count_dict.json'), ehr_count_dict)
+
+
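+# Read PATIENTS.csv and ICUSTAYS.csv to build gender/age, admission-time, and subject-id mappings for the selected ICU stays.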
+def generate_demo():
+ icu_hadm_dict = py_op.myreadjson('../../src/icu_hadm_dict.json')
+ py_op.mywritejson(os.path.join(args.data_dir, 'icu_hadm_dict.json'), icu_hadm_dict)
+
+ sid_demo_dict = dict()
+ sid_hadm_dict = dict()
+ for i_line, line in enumerate(open(os.path.join(args.mimic_dir, 'PATIENTS.csv'))):
+ if i_line:
+ data = line.split(',')
+ sid = data[1]
+ gender = data[2].replace('"', '')
+ dob = data[3][:4]
+ sid_demo_dict[sid] = [gender, int(dob)]
+ py_op.mywritejson(os.path.join(args.data_dir, 'sid_demo_dict.json'), sid_demo_dict)
+
+ hadm_sid_dict = dict()
+ hadm_demo_dict = dict()
+ hadm_time_dict = dict()
+ for i_line, line in enumerate(open(os.path.join(args.mimic_dir, 'ICUSTAYS.csv'))):
+ if i_line:
+ line = line.replace('"', '')
+ data = line.split(',')
+ sid = data[1]
+ hadm_id = data[2]
+ icu_id = data[3]
+ intime = data[-3]
+ sid_hadm_dict[sid] = sid_hadm_dict.get(sid, []) + [hadm_id]
+ if icu_id not in icu_hadm_dict:
+ continue
+ hadm_sid_dict[hadm_id] = sid
+ gender = sid_demo_dict[sid][0]
+ dob = sid_demo_dict[sid][1]
+ age = int(intime[:4]) - dob
+ if age < 18:
+ print(age)
+ assert age >= 18
+ if age > 150:
+ age = 90
+ hadm_demo_dict[hadm_id] = [gender, age]
+ hadm_time_dict[hadm_id] = intime
+ py_op.mywritejson(os.path.join(args.data_dir, 'hadm_demo_dict.json'), hadm_demo_dict)
+ py_op.mywritejson(os.path.join(args.data_dir, 'hadm_time_dict.json'), hadm_time_dict)
+ py_op.mywritejson(os.path.join(args.data_dir, 'sid_hadm_dict.json'), sid_hadm_dict)
+ py_op.mywritejson(os.path.join(args.data_dir, 'hadm_sid_dict.json'), hadm_sid_dict)
+
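+# For each selected admission, collect the ICD codes recorded during the same patient's earlier admissions (diagnosis history).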
+def generate_diagnosis_data():
+ sid_hadm_dict = py_op.myreadjson(os.path.join(args.data_dir, 'sid_hadm_dict.json') )
+ hadm_sid_dict = py_op.myreadjson(os.path.join(args.data_dir, 'hadm_sid_dict.json'))
+
+ hadm_map_dict = dict()
+ for hadm in hadm_sid_dict:
+ sid = hadm_sid_dict[hadm]
+ hadm_list = sid_hadm_dict[sid]
+ if len(hadm_list) > 1:
+ hadm_list = sorted(hadm_list, key=lambda k:int(k))
+ idx = hadm_list.index(hadm)
+ if idx > 0:
+ for h in hadm_list[:idx]:
+ if h not in hadm_map_dict:
+ hadm_map_dict[h] = []
+ hadm_map_dict[h].append(hadm)
+
+ hadm_icd_dict = dict()
+ for i_line, line in enumerate(open(os.path.join(args.mimic_dir, 'DIAGNOSES_ICD.csv'))):
+ if i_line:
+ if i_line % 10000 == 0:
+ print(i_line)
+ line_data = [x.strip('"') for x in py_op.csv_split(line.strip())]
+ ROW_ID, SUBJECT_ID, hadm_id, SEQ_NUM, icd = line_data
+ if hadm_id in hadm_map_dict:
+ for h in hadm_map_dict[hadm_id]:
+ if h not in hadm_icd_dict:
+ hadm_icd_dict[h] = []
+ hadm_icd_dict[h].append(icd)
+ hadm_icd_dict = { h:list(set(icds)) for h, icds in hadm_icd_dict.items() }
+ py_op.mywritejson(os.path.join(args.data_dir, 'hadm_icd_dict.json'), hadm_icd_dict)
+
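+# Group PRESCRIPTIONS.csv rows by admission and start/end date window, keeping the prescribed drug names.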
+def generate_drug_data():
+ hadm_sid_dict = py_op.myreadjson(os.path.join(args.data_dir, 'hadm_sid_dict.json'))
+ hadm_id_set = set(hadm_sid_dict)
+ hadm_time_drug_dict = dict()
+ for i_line, line in enumerate(open(os.path.join(args.mimic_dir, 'PRESCRIPTIONS.csv'))):
+ if i_line:
+ if i_line % 10000 == 0:
+ print(i_line)
+ line_data = [x.strip('"') for x in py_op.csv_split(line.strip())]
+ _, SUBJECT_ID,hadm_id,_,startdate,enddate,_,drug,DRUG_NAME_POE,DRUG_NAME_GENERIC,FORMULARY_DRUG_CD,gsn,ndc,PROD_STRENGTH,DOSE_VAL_RX,DOSE_UNIT_RX,FORM_VAL_DISP,FORM_UNIT_DISP,ROUTE = line_data
+ if len(hadm_id) and hadm_id in hadm_id_set:
+ if hadm_id not in hadm_time_drug_dict:
+ hadm_time_drug_dict[hadm_id] = dict()
+ time = startdate + ' -- ' + enddate
+ if time not in hadm_time_drug_dict[hadm_id]:
+ hadm_time_drug_dict[hadm_id][time] = []
+ hadm_time_drug_dict[hadm_id][time].append(drug)
+ # hadm_time_drug_dict[hadm_id][time].append(ndc)
+ py_op.mywritejson(os.path.join(args.data_dir, 'hadm_time_drug_dict.json'), hadm_time_drug_dict)
+
+
+
+
+def main():
+ generate_demo()
+
+ generate_diagnosis_data()
+ generate_drug_data()
+ generate_ehr_files()
+ map_ehr_id()
+
+if __name__ == '__main__':
+ main()
diff --git a/code/tools/mimic_op.py b/code/tools/mimic_op.py
new file mode 100644
index 0000000..45d7295
--- /dev/null
+++ b/code/tools/mimic_op.py
@@ -0,0 +1,14 @@
+# coding=utf8
+
+import parse
+args = parse.args
+
+def get_line_data(line):
+ pass
+
+def select_records_according_subjectid(subject_ids, input_file, output_file):
+ pass
+
+
+
+
diff --git a/code/tools/parse.py b/code/tools/parse.py
new file mode 100644
index 0000000..58ab14c
--- /dev/null
+++ b/code/tools/parse.py
@@ -0,0 +1,225 @@
+# coding=utf8
+
+import os
+import argparse
+
+parser = argparse.ArgumentParser(description='MIMIC III PROJECTS')
+
+# data dir
+parser.add_argument(
+ '--data-dir',
+ type=str,
+ default='../../data/',
+ help='selected and preprocessed data directory'
+ )
+parser.add_argument(
+ '--result-dir',
+ type=str,
+ default='../../result/',
+ help='result directory'
+ )
+parser.add_argument(
+ '--file-dir',
+ type=str,
+ default='../../file/',
+ help='useful file directory'
+ )
+parser.add_argument(
+ '--mimic-dir',
+ type=str,
+ default='../../data/MIMIC/initial_mimiciii/',
+ help='directory of the raw MIMIC-III csv files'
+ )
+
+parser.add_argument(
+ '--dataset',
+ default='DACMI',
+ # default='MIMIC',
+ type=str,
+ help='dataset')
+
+parser.add_argument(
+ '--n-code',
+ default=8,
+ type=int,
+ help='at most n codes for same visit')
+parser.add_argument(
+ '--n-visit',
+ default=30,
+ type=int,
+ help='at most input n visits')
+parser.add_argument(
+ '--nc',
+ default=4,
+ type=int,
+ help='n clusters')
+parser.add_argument(
+ '--brnn',
+ default=True,
+ type=bool,
+ help='use bidirectional RNN or not')
+parser.add_argument(
+ '--random-missing',
+ default=True,
+ type=bool,
+ help='use random missing values for training')
+
+
+
+# method settings
+parser.add_argument(
+ '--model',
+ '-m',
+ type=str,
+ default='tame',
+ help='model'
+ )
+parser.add_argument(
+ '--split-num',
+ metavar='split num',
+ type=int,
+ default=4000,
+ help='split num'
+ )
+parser.add_argument(
+ '--n-records',
+ metavar='input size',
+ type=int,
+ default=30,
+ help='input size'
+ )
+parser.add_argument(
+ '--split-nor',
+ metavar='split normal range',
+ type=int,
+ default=3,
+ help='number of splits within the normal range'
+ )
+parser.add_argument(
+ '--use-ta',
+ metavar='use time-aware attention',
+ type=int,
+ default=1,
+ help='use time-aware attention'
+ )
+parser.add_argument(
+ '--use-ve',
+ metavar='use value embedding',
+ type=int,
+ default=1,
+ help='use value-embedding'
+ )
+parser.add_argument(
+ '--use-mm',
+ metavar='use multi-modal input',
+ type=int,
+ default=0,
+ help='use multi-modal input'
+ )
+parser.add_argument(
+ '--value-embedding',
+ metavar='value embedding',
+ type=str,
+ # default='use_value',
+ default='use_order',
+ # default='no',
+ help='use_value or use_order or no'
+ )
+parser.add_argument(
+ '--loss',
+ type=str,
+ # default='missing',
+ # default='init',
+ default='both',
+ help='loss function, missing, init, both'
+ )
+
+
+# model parameters
+parser.add_argument(
+ '--embed-size',
+ metavar='EMBED SIZE',
+ type=int,
+ default=512,
+ help='embed size'
+ )
+parser.add_argument(
+ '--rnn-size',
+ metavar='rnn SIZE',
+ type=int,
+ help='rnn size'
+ )
+parser.add_argument(
+ '--hidden-size',
+ metavar='hidden SIZE',
+ type=int,
+ help='hidden size'
+ )
+parser.add_argument(
+ '--num-layers',
+ metavar='num layers',
+ type=int,
+ default=2,
+ help='num layers'
+ )
+
+
+
+# training process settings
+parser.add_argument('--phase',
+ default='train',
+ type=str,
+ metavar='S',
+ help='pretrain/train/test phase')
+parser.add_argument(
+ '--batch-size',
+ '-b',
+ metavar='BATCH SIZE',
+ type=int,
+ default=64,
+ help='batch size'
+ )
+parser.add_argument('--resume',
+ default='',
+ type=str,
+ metavar='S',
+ help='start from checkpoints')
+parser.add_argument(
+ '--compute-weight',
+ default=0,
+ type=int,
+ help='compute weight for interpretability')
+parser.add_argument(
+ '--workers',
+ default=16,
+ type=int,
+ metavar='N',
+ help='number of data loading workers (default: 16)')
+parser.add_argument('--lr',
+ '--learning-rate',
+ default=0.001,
+ type=float,
+ metavar='LR',
+ help='initial learning rate')
+parser.add_argument('--epochs',
+ default=2000,
+ type=int,
+ metavar='N',
+ help='number of total epochs to run')
+parser.add_argument('--save-freq',
+ default=1,
+ type=int,
+ metavar='S',
+ help='save frequency')
+parser.add_argument('--save-pred-freq',
+ default=10,
+ type=int,
+ metavar='S',
+ help='save pred clean frequency')
+parser.add_argument('--val-freq',
+ default=1,
+ type=int,
+ metavar='S',
+ help='val frequency')
+
+args = parser.parse_args()
diff --git a/code/tools/py_op.py b/code/tools/py_op.py
new file mode 100644
index 0000000..4885294
--- /dev/null
+++ b/code/tools/py_op.py
@@ -0,0 +1,134 @@
+# -*- coding: utf-8 -*-
+"""
+Utility functions for common Python operations.
+"""
+import os
+import json
+import traceback
+from collections import OrderedDict
+import random
+try:
+ from fuzzywuzzy import fuzz # fuzz.partial_ratio is only needed by myfuzzymatch and fuzz_list
+except ImportError:
+ fuzz = None
+
+import sys
+# reload(sys)
+# sys.setdefaultencoding('utf-8')
+
+################################################################################
+### pre define variables
+#:: enumerate
+#:: raw_input
+#:: listdir
+#:: sorted
+### pre define function
+def mywritejson(save_path,content):
+ content = json.dumps(content,indent=4,ensure_ascii=False)
+ with open(save_path,'w') as f:
+ f.write(content)
+
+def myreadjson(load_path):
+ with open(load_path,'r') as f:
+ return json.loads(f.read())
+
+def mywritefile(save_path,content):
+ with open(save_path,'w') as f:
+ f.write(content)
+
+def myreadfile(load_path):
+ with open(load_path,'r') as f:
+ return f.read()
+
+def myprint(content):
+ print(json.dumps(content,indent=4,ensure_ascii=False))
+
+def rm(fi):
+ os.system('rm ' + fi)
+
+def mystrip(s):
+ return ''.join(s.split())
+
+def mysorteddict(d,key = lambda s:s, reverse=False):
+ dordered = OrderedDict()
+ for k in sorted(d.keys(),key = key,reverse=reverse):
+ dordered[k] = d[k]
+ return dordered
+
+def mysorteddictfile(src,obj):
+ mywritejson(obj,mysorteddict(myreadjson(src)))
+
+def myfuzzymatch(srcs,objs,grade=80):
+ matchDict = OrderedDict()
+ for src in srcs:
+ for obj in objs:
+ value = fuzz.partial_ratio(src,obj)
+ if value > grade:
+ try:
+ matchDict[src].append(obj)
+ except:
+ matchDict[src] = [obj]
+ return matchDict
+
+def mydumps(x):
+ return json.dumps(x,indent=4,ensure_ascii=False)
+
+def get_random_list(l,num=-1,isunique=0):
+ if isunique:
+ l = set(l)
+ if num < 0:
+ num = len(l)
+ if isunique and num > len(l):
+ return
+ lnew = []
+ l = list(l)
+ while(num>len(lnew)):
+ x = l[int(random.random()*len(l))]
+ if isunique and x in lnew:
+ continue
+ lnew.append(x)
+ return lnew
+
+def fuzz_list(node1_list,node2_list,score_baseline=66,proposal_num=10,string_map=None):
+ node_dict = { }
+ for i,node1 in enumerate(node1_list):
+ match_score_dict = { }
+ for node2 in node2_list:
+ if node1 != node2:
+ if string_map is not None:
+ n1 = string_map(node1)
+ n2 = string_map(node2)
+ score = fuzz.partial_ratio(n1,n2)
+ if n1 == n2:
+ node2_list.remove(node2)
+ else:
+ score = fuzz.partial_ratio(node1,node2)
+ if score > score_baseline:
+ match_score_dict[node2] = score
+ else:
+ node2_list.remove(node2)
+ node2_sort = sorted(match_score_dict.keys(), key=lambda k:match_score_dict[k],reverse=True)
+ node_dict[node1] = [[n,match_score_dict[n]] for n in node2_sort[:proposal_num]]
+ print(i,len(node1_list))
+ return node_dict, node2_list
+
+def swap(a,b):
+ return b, a
+
+def mkdir(d):
+ path = d.split('/')
+ for i in range(len(path)):
+ d = '/'.join(path[:i+1])
+ if not os.path.exists(d):
+ os.mkdir(d)
+
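+# Split a csv line on sc, ignoring separators that appear inside double-quoted fields.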
+def csv_split(line, sc=','):
+ res = []
+ inside = 0
+ s = ''
+ for c in line:
+ if inside == 0 and c == sc:
+ res.append(s)
+ s = ''
+ else:
+ if c == '"':
+ inside = 1 - inside
+ s = s + c
+ res.append(s)
+ return res
diff --git a/src/Model_v2.png b/src/Model_v2.png
new file mode 100644
index 0000000..b07197c
Binary files /dev/null and b/src/Model_v2.png differ
diff --git a/src/UI_V2.png b/src/UI_V2.png
new file mode 100644
index 0000000..552a742
Binary files /dev/null and b/src/UI_V2.png differ
diff --git a/src/asratio_AmsterdamUMCdb.png b/src/asratio_AmsterdamUMCdb.png
new file mode 100644
index 0000000..6ae7607
Binary files /dev/null and b/src/asratio_AmsterdamUMCdb.png differ
diff --git a/src/asratio_MIMIC-III.png b/src/asratio_MIMIC-III.png
new file mode 100644
index 0000000..092988d
Binary files /dev/null and b/src/asratio_MIMIC-III.png differ
diff --git a/src/asratio_OSUWMC.png b/src/asratio_OSUWMC.png
new file mode 100644
index 0000000..aec3568
Binary files /dev/null and b/src/asratio_OSUWMC.png differ
diff --git a/src/auroc-uncertainty-as.png b/src/auroc-uncertainty-as.png
new file mode 100644
index 0000000..83e7019
Binary files /dev/null and b/src/auroc-uncertainty-as.png differ
diff --git a/src/auroc-uncertainty-obs.png b/src/auroc-uncertainty-obs.png
new file mode 100644
index 0000000..f5a51d7
Binary files /dev/null and b/src/auroc-uncertainty-obs.png differ
diff --git a/src/interaction_v3.png b/src/interaction_v3.png
new file mode 100644
index 0000000..6242d44
Binary files /dev/null and b/src/interaction_v3.png differ
diff --git a/src/setting.PNG b/src/setting.PNG
new file mode 100644
index 0000000..c55e386
Binary files /dev/null and b/src/setting.PNG differ