-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathdatautils.py
137 lines (105 loc) · 4.53 KB
/
datautils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
from torch.utils.data import Dataset, DataLoader
from mol_tree import MolTree
import numpy as np
from jtnn_enc import JTNNEncoder
from mpn import MPN
from jtmpn import JTMPN
import cPickle as pickle
import os, random
class PairTreeFolder(object):
    """Streams batches of (x, y) molecule-tree pairs from pickled shard files.

    Each file in `data_folder` is expected to be a pickled list of
    (tree_x, tree_y) pairs. Iteration loads one shard at a time, slices it
    into fixed-size batches, and yields tensorized batches through a
    DataLoader so the tensorization can run in worker processes.
    """

    def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, y_assm=True, replicate=None):
        self.data_folder = data_folder
        self.data_files = [fn for fn in os.listdir(data_folder)]
        self.batch_size = batch_size
        self.vocab = vocab
        self.num_workers = num_workers
        self.y_assm = y_assm  # build assembly candidates for the y (target) tree only
        self.shuffle = shuffle
        if replicate is not None:  # replicate is an int: repeat the file list that many times
            self.data_files = self.data_files * replicate

    def __iter__(self):
        for fn in self.data_files:
            fn = os.path.join(self.data_folder, fn)
            # Pickle is a binary format: open in 'rb' so loading also works on
            # platforms where text mode alters bytes (and under Python 3).
            with open(fn, 'rb') as f:
                data = pickle.load(f)

            if self.shuffle:
                random.shuffle(data)  # shuffle data before batching

            # range() behaves identically here on Python 2 and 3 (xrange is Py2-only).
            batches = [data[i: i + self.batch_size] for i in range(0, len(data), self.batch_size)]
            # Drop a trailing partial batch; guard against an empty shard, where
            # batches == [] and batches[-1] would raise IndexError.
            if batches and len(batches[-1]) < self.batch_size:
                batches.pop()

            dataset = PairTreeDataset(batches, self.vocab, self.y_assm)
            # batch_size=1 with an identity-style collate: each dataset "item"
            # is already a full batch assembled above.
            # NOTE(review): a lambda collate_fn is not picklable, so
            # num_workers > 0 may fail on spawn-based platforms — confirm.
            dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, collate_fn=lambda x: x[0])

            for b in dataloader:
                yield b

            # Release the shard's data before loading the next one to cap memory.
            del data, batches, dataset, dataloader
class MolTreeFolder(object):
    """Streams batches of molecule trees from pickled shard files.

    Each file in `data_folder` is expected to be a pickled list of MolTree
    objects. Iteration loads one shard at a time, slices it into fixed-size
    batches, and yields tensorized batches through a DataLoader so the
    tensorization can run in worker processes.
    """

    def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, assm=True, replicate=None):
        self.data_folder = data_folder
        self.data_files = [fn for fn in os.listdir(data_folder)]
        self.batch_size = batch_size
        self.vocab = vocab
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.assm = assm  # whether to build assembly candidates for each tree
        if replicate is not None:  # replicate is an int: repeat the file list that many times
            self.data_files = self.data_files * replicate

    def __iter__(self):
        for fn in self.data_files:
            fn = os.path.join(self.data_folder, fn)
            # Pickle is a binary format: open in 'rb' so loading also works on
            # platforms where text mode alters bytes (and under Python 3).
            with open(fn, 'rb') as f:
                data = pickle.load(f)

            if self.shuffle:
                random.shuffle(data)  # shuffle data before batching

            # range() behaves identically here on Python 2 and 3 (xrange is Py2-only).
            batches = [data[i: i + self.batch_size] for i in range(0, len(data), self.batch_size)]
            # Drop a trailing partial batch; guard against an empty shard, where
            # batches == [] and batches[-1] would raise IndexError.
            if batches and len(batches[-1]) < self.batch_size:
                batches.pop()

            dataset = MolTreeDataset(batches, self.vocab, self.assm)
            # batch_size=1 with an identity-style collate: each dataset "item"
            # is already a full batch assembled above.
            # NOTE(review): a lambda collate_fn is not picklable, so
            # num_workers > 0 may fail on spawn-based platforms — confirm.
            dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, collate_fn=lambda x: x[0])

            for b in dataloader:
                yield b

            # Release the shard's data before loading the next one to cap memory.
            del data, batches, dataset, dataloader
class PairTreeDataset(Dataset):
    """Dataset over pre-built batches of (x, y) molecule-tree pairs.

    Each item is one whole batch: the pairs are unzipped and each side is
    tensorized separately — the x side never gets assembly candidates, the
    y side gets them only when `y_assm` is True.
    """

    def __init__(self, data, vocab, y_assm):
        self.data = data      # list of batches; each batch is a list of (tree_x, tree_y) pairs
        self.vocab = vocab    # cluster vocabulary used by tensorize()
        self.y_assm = y_assm  # build assembly candidates for the y side?

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x_batch, y_batch = zip(*self.data[idx])
        x_tensors = tensorize(x_batch, self.vocab, assm=False)
        y_tensors = tensorize(y_batch, self.vocab, assm=self.y_assm)
        return x_tensors, y_tensors
class MolTreeDataset(Dataset):
    """Dataset over pre-built batches of molecule trees.

    Each item is one whole batch, tensorized on access; `assm` controls
    whether assembly candidates are built.
    """

    def __init__(self, data, vocab, assm=True):
        self.data = data    # list of batches; each batch is a list of trees
        self.vocab = vocab  # cluster vocabulary used by tensorize()
        self.assm = assm    # build assembly candidates?

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        batch = self.data[idx]
        return tensorize(batch, self.vocab, assm=self.assm)
def tensorize(tree_batch, vocab, assm=True):
    """Convert a batch of molecule trees into model-ready tensors.

    Assigns batch-wide node ids, then builds the junction-tree encoder
    tensors and the molecular-graph (MPN) tensors. When `assm` is True,
    additionally gathers every non-leaf node's attachment candidates and
    tensorizes them for the JTMPN, returning them together with a
    LongTensor mapping each candidate back to its tree in the batch.
    """
    set_batch_nodeID(tree_batch, vocab)
    jtenc_holder, mess_dict = JTNNEncoder.tensorize(tree_batch)
    mpn_holder = MPN.tensorize([tree.smiles for tree in tree_batch])

    if assm is False:
        return tree_batch, jtenc_holder, mpn_holder

    cands = []
    batch_idx = []
    for tree_idx, tree in enumerate(tree_batch):
        for node in tree.nodes:
            # A leaf node's attachment is determined by its neighbor, and a
            # single candidate leaves nothing to score — skip both cases.
            if node.is_leaf or len(node.cands) == 1:
                continue
            for cand in node.cands:
                cands.append((cand, tree.nodes, node))
            batch_idx.extend([tree_idx] * len(node.cands))

    jtmpn_holder = JTMPN.tensorize(cands, mess_dict)
    return tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, torch.LongTensor(batch_idx))
def set_batch_nodeID(mol_batch, vocab):
    """Assign each node a batch-unique running index (`idx`) and its
    vocabulary id (`wid`), numbering nodes in tree order across the batch."""
    next_id = 0
    for tree in mol_batch:
        for node in tree.nodes:
            node.idx = next_id
            node.wid = vocab.get_index(node.smiles)
            next_id += 1