-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdataset.py
86 lines (65 loc) · 2.89 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# import pandas as pd
import numpy as np
import torch
import torch.utils.data as data
import random
from sklearn.utils import check_random_state
##-----------------------------------------------------------------------------------------------------------
class TreeDataset(data.Dataset):
    '''
    Dataset that serves already-batched tree data.

    Unlike a typical ``Dataset``, ``__getitem__`` is indexed by *batch*:
    ``dataset[i]`` covers samples ``[i*batch_size, (i+1)*batch_size)``.
    ``__len__`` nevertheless reports the number of *samples*, so valid batch
    indices are ``0 .. ceil(len(self)/batch_size) - 1``; larger indices hand
    an empty slice to ``transform``.

    Args:
        data: sequence of samples (NumPy array, list or tuple).
        labels: labels of each element of ``data``, in the same order.
        shuffle: if True, apply one random permutation to ``data`` and
            ``labels`` at construction time (uses NumPy's global random state).
        transform: callable invoked as ``transform(batch, features)`` that
            creates the training batch. It must return five sequences:
            ``(levels, children, n_inners, contents, n_level)``.
        batch_size: number of samples per training batch.
        features: number of features in each node; forwarded to ``transform``.
    '''

    def __init__(self, data=None, labels=None, shuffle=True, transform=None, batch_size=None, features=None):
        super(TreeDataset, self).__init__()
        self.data = data
        self.labels = labels
        self.transform = transform
        self.batch_size = batch_size
        self.features = features

        # Nothing to shuffle when the dataset is built without data (the
        # default arguments), so only permute when data is present.
        if shuffle and self.data is not None:
            # np.random.permutation draws from NumPy's global RandomState,
            # which is what sklearn's check_random_state(seed=None) returns.
            indices = np.random.permutation(len(self.data))
            self.data = self._reorder(self.data, indices)
            if self.labels is not None:
                # Same permutation keeps every sample paired with its label.
                self.labels = self._reorder(self.labels, indices)

    @staticmethod
    def _reorder(items, indices):
        '''Return ``items`` rearranged according to the integer array ``indices``.'''
        if isinstance(items, (list, tuple)):
            # Plain Python sequences do not support fancy indexing.
            return [items[i] for i in indices]
        return items[indices]

    def __getitem__(self, index):
        '''
        Build the ``index``-th training batch.

        Returns:
            Tuple ``(levels, children, n_inners, contents, n_level, labels)``
            of NumPy arrays; the first five come from ``transform`` and
            ``labels`` is the matching slice of ``self.labels``.

        Raises:
            ValueError: if ``transform`` or ``batch_size`` was not provided.
        '''
        if self.transform is None:
            raise ValueError('TreeDataset needs a `transform` to build batches')
        if self.batch_size is None:
            raise ValueError('TreeDataset needs a `batch_size` to build batches')

        start = index * self.batch_size
        stop = start + self.batch_size

        levels, children, n_inners, contents, n_level = self.transform(
            self.data[start:stop], self.features
        )
        labels = np.asarray(self.labels[start:stop])

        # Shift everything to np arrays
        return (
            np.asarray(levels),
            np.asarray(children),
            np.asarray(n_inners),
            np.asarray(contents),
            np.asarray(n_level),
            labels,
        )

    def __len__(self):
        '''Return the number of samples (not batches); 0 when no data was given.'''
        if self.data is None:
            return 0
        return len(self.data)
##-----------------------------------------------------------------------------------------------------------
def customized_collate(batch):
    """
    Collate function that shifts one pre-built batch tuple to pytorch tensors.

    Only ``batch[0]`` is read; any further elements of ``batch`` are ignored
    (presumably the DataLoader is run with a batch size of 1 because the
    dataset already returns whole batches -- confirm against the caller).
    Defined at module level, as collate functions handed to DataLoader
    workers conventionally are.

    Args:
        batch: sequence whose first element is the tuple
            ``(levels, children, n_inners, contents, n_level, labels)``.

    Returns:
        The same six entries as a tuple of tensors: ``contents`` as a
        ``torch.FloatTensor``, all others as ``torch.LongTensor``.
    """
    sample = batch[0]
    # One tensor constructor per position of the sample tuple, in order:
    # levels, children, n_inners, contents, n_level, labels.
    constructors = (
        torch.LongTensor,
        torch.LongTensor,
        torch.LongTensor,
        torch.FloatTensor,
        torch.LongTensor,
        torch.LongTensor,
    )
    return tuple(make(sample[pos]) for pos, make in enumerate(constructors))