-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_raw.py
executable file
·124 lines (110 loc) · 4.66 KB
/
data_raw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import random
import pandas as pd
import numpy as np
import nibabel as nib
from torch.utils.data import Dataset
import torch.nn.functional as F
def load_image(image_path, mean, std, threshold=(-1200, 600)):
    """Load a NIfTI volume, clip its intensities and standardize them.

    Args:
        image_path: Path passed to ``nibabel.load``.
        mean: Value subtracted from every voxel after clipping.
        std: Value every voxel is divided by after the mean is subtracted.
        threshold: ``(low, high)`` bounds the voxel values are clipped to
            before normalization. An immutable tuple is used as the default
            so the default object can never be modified between calls; any
            two-element sequence is still accepted.

    Returns:
        The floating-point array produced by ``get_fdata()``, normalized and
        with its axes reversed via ``transpose(2, 1, 0)`` (which requires the
        volume to be 3-D).
    """
    low, high = threshold[0], threshold[1]
    image = nib.load(image_path).get_fdata()
    # All three steps write back into the freshly loaded array so no extra
    # full-size copies are allocated.
    np.clip(image, low, high, out=image)
    np.subtract(image, mean, out=image)
    np.divide(image, std, out=image)
    # Reverse the axis order: the last axis of the stored volume comes first.
    return image.transpose(2, 1, 0)
def load_image_norm(image_path, threshold=(-1200, 600)):
    """Load a NIfTI volume, clip it and rescale it to the range [0, 1].

    Args:
        image_path: Path passed to ``nibabel.load``.
        threshold: ``(low, high)`` clipping bounds. After clipping, ``low``
            maps to 0 and ``high`` maps to 1. An immutable tuple is used as
            the default so the default object can never be modified between
            calls; any two-element sequence is still accepted.

    Returns:
        The rescaled floating-point array with its axes reversed via
        ``transpose(2, 1, 0)`` (which requires the volume to be 3-D).
    """
    low, high = threshold[0], threshold[1]
    image = nib.load(image_path).get_fdata()
    # Clip in place on the freshly loaded array to avoid an extra copy.
    np.clip(image, low, high, out=image)
    # Min-max rescale using the clipping bounds as the min and max.
    image = (image - low) / (high - low)
    # Reverse the axis order: the last axis of the stored volume comes first.
    return image.transpose(2, 1, 0)
class TrainDataset(Dataset):
    """Dataset that yields a randomly sampled stack of 16 slices per volume.

    Volume names come from the ``name`` column of ``train_csv``; labels are
    looked up by name in the ``four_label`` column of ``label_csv`` (whose
    first column is used as the index).
    """

    def __init__(self, data_dir, train_csv, label_csv):
        """Read the name list and the label table.

        Args:
            data_dir: Prefix that is concatenated directly with
                ``name + ".nii.gz"``; it therefore has to end with a path
                separator.
            train_csv: CSV file with a ``name`` column.
            label_csv: CSV file whose first column holds the names and which
                has a ``four_label`` column.
        """
        self.data_dir = data_dir
        train_df = pd.read_csv(train_csv)
        self.names_train = train_df["name"]
        self.labels_train_df = pd.read_csv(label_csv, index_col=0)
        # Normalization constants forwarded to load_image.
        # NOTE(review): presumably dataset-wide intensity statistics - confirm.
        self.mean = -604.2288900583559
        self.std = 489.42172740885655

    @staticmethod
    def _choose_indices(depth, num_slices):
        """Return ``num_slices`` slice indices for a volume of ``depth`` slices.

        Deeper volumes are sampled with a larger stride and with the random
        start kept away from both ends of the volume. The fixed offsets
        below were written for ``num_slices == 16``.
        """
        if depth <= 80:
            # Contiguous window; volumes no deeper than the window start at 0.
            step = 1
            if depth <= num_slices:
                start = 0
            else:
                start = random.randrange(0, depth - num_slices)
        elif depth <= 160:
            step = 2
            start = random.randrange(10, depth - 60)
        else:
            step = 5
            start = random.randrange(20, depth - 130)
        return [start + i * step for i in range(num_slices)]

    @classmethod
    def _crop_slices(cls, volume, num_slices):
        """Sample ``num_slices`` slices along the first axis of a 3-D tensor.

        A volume with fewer than ``num_slices`` slices is zero-padded along
        the first axis exactly once, split as evenly as possible between
        front and back. Padding inside the gathering loop would grow the
        tensor on every iteration and shift the real slices away from the
        requested indices.
        """
        depth, _, _ = volume.shape
        indices = cls._choose_indices(depth, num_slices)
        if depth < num_slices:
            left_pad = (num_slices - depth) // 2
            right_pad = num_slices - left_pad - depth
            # F.pad lists padding from the last dimension backwards, so the
            # final pair applies to the first (slice) axis.
            volume = F.pad(volume, (0, 0, 0, 0, left_pad, right_pad), "constant")
        return torch.stack([volume[i] for i in indices], 0).float()

    def __getitem__(self, item):
        """Load one volume and return ``(crop, label, name)``.

        ``crop`` is a float tensor holding ``margin * 2`` slices stacked
        along the first axis.
        """
        margin = 8
        name_train = self.names_train[item]
        label_train = self.labels_train_df.at[name_train, "four_label"]
        path_train = self.data_dir + name_train + ".nii.gz"
        image_train = load_image(path_train, self.mean, self.std)
        image_train = torch.from_numpy(image_train).float()
        image_train_crop = self._crop_slices(image_train, margin * 2)
        return image_train_crop, label_train, name_train

    def __len__(self):
        """Return the number of volume names read from ``train_csv``."""
        return len(self.names_train)
class TestDataset(Dataset):
    """Dataset that yields a sampled stack of 16 slices plus patient id.

    Volume names come from the ``name`` column of ``test_csv``; the label
    and patient id are looked up by name in the ``four_label`` and
    ``patient_id`` columns of ``label_csv`` (whose first column is used as
    the index).
    """

    def __init__(self, data_dir, test_csv, label_csv):
        """Read the name list and the label table.

        Args:
            data_dir: Prefix that is concatenated directly with
                ``name + ".nii.gz"``; it therefore has to end with a path
                separator.
            test_csv: CSV file with a ``name`` column.
            label_csv: CSV file whose first column holds the names and which
                has ``four_label`` and ``patient_id`` columns.
        """
        self.data_dir = data_dir
        test_df = pd.read_csv(test_csv)
        self.names_test = test_df["name"]
        self.labels_test_df = pd.read_csv(label_csv, index_col=0)
        # Normalization constants forwarded to load_image.
        # NOTE(review): presumably dataset-wide intensity statistics - confirm.
        self.mean = -604.2288900583559
        self.std = 489.42172740885655

    @staticmethod
    def _choose_indices(depth, num_slices):
        """Return ``num_slices`` slice indices for a volume of ``depth`` slices.

        Deeper volumes are sampled with a larger stride and with the random
        start kept away from both ends of the volume. The fixed offsets
        below were written for ``num_slices == 16``.

        NOTE(review): the start is drawn from ``random`` here as well, so
        the crops are only reproducible if the caller seeds ``random``.
        """
        if depth <= 80:
            # Contiguous window; volumes no deeper than the window start at 0.
            step = 1
            if depth <= num_slices:
                start = 0
            else:
                start = random.randrange(0, depth - num_slices)
        elif depth <= 160:
            step = 2
            start = random.randrange(10, depth - 60)
        else:
            step = 5
            start = random.randrange(30, depth - 120)
        return [start + i * step for i in range(num_slices)]

    @classmethod
    def _crop_slices(cls, volume, num_slices):
        """Sample ``num_slices`` slices along the first axis of a 3-D tensor.

        A volume with fewer than ``num_slices`` slices is zero-padded along
        the first axis exactly once, split as evenly as possible between
        front and back. Padding inside the gathering loop would grow the
        tensor on every iteration and shift the real slices away from the
        requested indices.
        """
        depth, _, _ = volume.shape
        indices = cls._choose_indices(depth, num_slices)
        if depth < num_slices:
            left_pad = (num_slices - depth) // 2
            right_pad = num_slices - left_pad - depth
            # F.pad lists padding from the last dimension backwards, so the
            # final pair applies to the first (slice) axis.
            volume = F.pad(volume, (0, 0, 0, 0, left_pad, right_pad), "constant")
        return torch.stack([volume[i] for i in indices], 0).float()

    def __getitem__(self, item):
        """Load one volume and return ``(crop, label, name, patient_id)``.

        ``crop`` is a float tensor holding ``margin * 2`` slices stacked
        along the first axis.
        """
        margin = 8
        name_test = self.names_test[item]
        label_test = self.labels_test_df.at[name_test, "four_label"]
        patient_id = self.labels_test_df.at[name_test, "patient_id"]
        path_test = self.data_dir + name_test + ".nii.gz"
        image_test = load_image(path_test, self.mean, self.std)
        image_test = torch.from_numpy(image_test).float()
        image_test_crop = self._crop_slices(image_test, margin * 2)
        return image_test_crop, label_test, name_test, patient_id

    def __len__(self):
        """Return the number of volume names read from ``test_csv``."""
        return len(self.names_test)