-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrun_my_classifier.py
executable file
·63 lines (49 loc) · 2.39 KB
/
run_my_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import numpy as np
from ExtractAllData import extractWholeRecord
import sys
from ModelDefinition import Sleep_model_MultiTarget
import os
import gc
# State-dict files of the trained models; the arousal predictions of all of
# them are averaged per record.
modelPaths = ['./TrainedModels/Model_Auxiliary1_stateDict.pkl', './TrainedModels/Model_Auxiliary2_stateDict.pkl',
              './TrainedModels/Model_Auxiliary3_stateDict.pkl', './TrainedModels/Model_Auxiliary4_stateDict.pkl']

# Set to True to run on a hard-coded record from an interactive session instead
# of taking record names from the command line.
testInPython = False
if testInPython:
    recordList = ['tr03-1371']
else:
    recordList = sys.argv[1:]

# Detect CUDA in both modes. Forcing True in test mode made the later .cuda()
# calls crash on CPU-only machines.
cudaAvailable = torch.cuda.is_available()
def loadModel(modelPath, useCuda):
    """Build a Sleep_model_MultiTarget in eval mode with weights from *modelPath*.

    The state dict is always deserialized onto the CPU (map_location='cpu') and
    the model is moved to the GPU afterwards only when *useCuda* is true. This
    way, weights saved from a GPU session can still be loaded on a CPU-only
    machine.

    NOTE(review): torch.load unpickles the file -- only load trusted model files.
    """
    model = Sleep_model_MultiTarget()
    model.load_state_dict(torch.load(modelPath, map_location='cpu'))
    model = model.eval()
    if useCuda:
        model = model.cuda()
    return model


def upsampleArousalPredictions(arousalOutput, scaleFactor=200):
    """Linearly upsample the model's arousal output and return it as a numpy vector.

    *arousalOutput* is indexed as [batch, channel, time]. The upsampling is
    along the last axis by *scaleFactor*, and only batch 0, channel 1 is kept.
    Presumably channel 1 is the arousal-class score -- TODO confirm against
    ModelDefinition.

    Returns a 1-D numpy array of length arousalOutput.shape[2] * scaleFactor.
    """
    # interpolate replaces the deprecated F.upsample. align_corners=False is
    # the value upsample resolved to for linear modes; passing it explicitly
    # silences the deprecation/align_corners warnings.
    upsampled = torch.nn.functional.interpolate(arousalOutput, scale_factor=scaleFactor,
                                                mode='linear', align_corners=False)
    return upsampled[0, 1, :].detach().cpu().numpy()


if __name__ == '__main__':
    for record in recordList:
        # Load and preprocess the whole record; annotations are not needed for inference.
        processedSignal = extractWholeRecord(recordName=str(record),
                                             extractAnnotations=False,
                                             dataPath='./',
                                             arousalAnnotationPath=None,
                                             apneaHypopneaAnnotationPath=None,
                                             sleepWakeAnnotationPath=None,
                                             dataInDirectory=False)
        # Accumulator for the ensemble: 4 output samples per row of processedSignal.
        # NOTE(review): presumably processedSignal is 4x downsampled relative to the
        # output rate -- confirm in ExtractAllData.
        arousalPredictions = np.zeros(processedSignal.shape[0] * 4)
        # Add a batch axis and swap the last two axes.
        # Assumes processedSignal is (time, channels) -> (1, channels, time) -- TODO confirm.
        processedSignal = torch.Tensor(processedSignal).unsqueeze(0).permute(0, 2, 1)
        if cudaAvailable:
            # Move the input once per record, not once per model.
            processedSignal = processedSignal.cuda()
        for modelPath in modelPaths:
            model = loadModel(modelPath, cudaAvailable)
            with torch.no_grad():
                # Only the first (arousal) output is used; indexing drops the
                # apnea/hypopnea and sleep/wake outputs immediately.
                arousalOutput = model(processedSignal)[0]
                arousalPredictions += upsampleArousalPredictions(arousalOutput)
            # Release this model and its output before loading the next one to
            # keep peak memory down.
            del arousalOutput
            del model
            torch.cuda.empty_cache()
            gc.collect()
        # Average over the ensemble.
        arousalPredictions /= float(len(modelPaths))
        # One value per line, written to <record basename>.vec in the working directory.
        output_file = os.path.basename(record) + '.vec'
        np.savetxt(output_file, arousalPredictions, fmt='%.3f')