-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDataSet.py
27 lines (21 loc) · 973 Bytes
/
DataSet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import numpy as np
import torch.utils.data as Data
class CMAPSS_Dataset_train(Data.Dataset):
    """Map-style dataset yielding ``(src_seq, trg_seq, RUL_label)`` triples.

    All three inputs are stored as ``float32`` tensors and indexed along
    their first dimension.
    """

    def __init__(self, src_seq, trg_seq, RUL_label):
        """Convert the inputs to ``float32`` tensors and store them.

        Args:
            src_seq: Array-like (NumPy array, nested list or tensor) of
                source sequences, one sample per leading index.
            trg_seq: Array-like of target sequences, one sample per
                leading index.
            RUL_label: Array-like of per-sample labels (presumably the
                remaining-useful-life targets -- confirm with the caller).

        Note:
            An input that is already a ``float32`` NumPy array or tensor
            may share memory with the stored tensor instead of being
            copied.
        """
        super().__init__()
        # torch.as_tensor handles ndarrays, lists and tensors uniformly and
        # avoids the legacy torch.FloatTensor constructor, which interprets
        # a bare int as a size and rejects tensors of a different dtype.
        self.src_seq = torch.as_tensor(src_seq, dtype=torch.float32)
        self.trg_seq = torch.as_tensor(trg_seq, dtype=torch.float32)
        self.RUL_label = torch.as_tensor(RUL_label, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return the ``(src_seq, trg_seq, RUL_label)`` entries at ``idx``."""
        return self.src_seq[idx], self.trg_seq[idx], self.RUL_label[idx]

    def __len__(self):
        """Return the number of samples, taken from the label tensor."""
        return len(self.RUL_label)
class CMAPSS_Dataset_valid_or_test(Data.Dataset):
    """Map-style dataset for evaluation that takes no target sequence.

    Items have the same ``(src_seq, trg_seq, RUL_label)`` layout as the
    training dataset; ``trg_seq`` is an all-zero tensor with the shape and
    dtype of ``src_seq``.
    """

    def __init__(self, src_seq, RUL_label):
        """Convert the inputs to ``float32`` tensors and build the zero target.

        Args:
            src_seq: Array-like (NumPy array, nested list or tensor) of
                source sequences, one sample per leading index.
            RUL_label: Array-like of per-sample labels (presumably the
                remaining-useful-life targets -- confirm with the caller).

        Note:
            An input that is already a ``float32`` NumPy array or tensor
            may share memory with the stored tensor instead of being
            copied.
        """
        super().__init__()
        # torch.as_tensor handles ndarrays, lists and tensors uniformly and
        # avoids the legacy torch.FloatTensor constructor, which interprets
        # a bare int as a size and rejects tensors of a different dtype.
        self.src_seq = torch.as_tensor(src_seq, dtype=torch.float32)
        # Zero placeholder so each item keeps the three-element layout.
        self.trg_seq = torch.zeros_like(self.src_seq)
        self.RUL_label = torch.as_tensor(RUL_label, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return the ``(src_seq, zero trg_seq, RUL_label)`` entries at ``idx``."""
        return self.src_seq[idx], self.trg_seq[idx], self.RUL_label[idx]

    def __len__(self):
        """Return the number of samples, taken from the label tensor."""
        return len(self.RUL_label)