import torch
import torch.nn as nn
import torch.nn.functional as F
from mol_tree import Vocab, MolTree
from nnutils import create_var, flatten_tensor, avg_pool
from jtnn_enc import JTNNEncoder
from jtnn_dec import JTNNDecoder
from mpn import MPN
from jtmpn import JTMPN
from chemutils import enum_assemble, set_atommap, copy_edit_mol, attach_mols
import rdkit
import rdkit.Chem as Chem
import copy, math

class DiffVAE(nn.Module):

    def __init__(self, vocab, args):
        super(DiffVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size = args.hidden_size
        self.rand_size = rand_size = args.rand_size

        self.jtmpn = JTMPN(hidden_size, args.depthG)
        self.mpn = MPN(hidden_size, args.depthG)

        if args.share_embedding:
            self.embedding = nn.Embedding(vocab.size(), hidden_size)
            self.jtnn = JTNNEncoder(hidden_size, args.depthT, self.embedding)
            self.decoder = JTNNDecoder(vocab, hidden_size, self.embedding, args.use_molatt)
        else:
            self.jtnn = JTNNEncoder(hidden_size, args.depthT, nn.Embedding(vocab.size(), hidden_size))
            self.decoder = JTNNDecoder(vocab, hidden_size, nn.Embedding(vocab.size(), hidden_size), args.use_molatt)

        self.A_assm = nn.Linear(hidden_size, hidden_size, bias=False)
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')  # size_average=False is deprecated

        # rand_size // 2 (not /) so the layer widths stay integers in Python 3
        self.T_mean = nn.Linear(hidden_size, rand_size // 2)
        self.T_var = nn.Linear(hidden_size, rand_size // 2)
        self.G_mean = nn.Linear(hidden_size, rand_size // 2)
        self.G_var = nn.Linear(hidden_size, rand_size // 2)
        self.B_t = nn.Sequential(nn.Linear(hidden_size + rand_size // 2, hidden_size), nn.ReLU())
        self.B_g = nn.Sequential(nn.Linear(hidden_size + rand_size // 2, hidden_size), nn.ReLU())
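
    # Note (added commentary, inferred from the code above): the T_* heads
    # parameterize the tree-side latent distribution and the G_* heads the
    # graph-side one; each latent code has rand_size // 2 dimensions, and
    # B_t / B_g project [encoding ; latent] back down to hidden_size.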

    def encode(self, jtenc_holder, mpn_holder):
        tree_vecs, tree_mess = self.jtnn(*jtenc_holder)
        mol_vecs = self.mpn(*mpn_holder)
        return tree_vecs, tree_mess, mol_vecs

    def fuse_noise(self, tree_vecs, mol_vecs):
        tree_eps = create_var(torch.randn(tree_vecs.size(0), 1, self.rand_size // 2))
        tree_eps = tree_eps.expand(-1, tree_vecs.size(1), -1)
        mol_eps = create_var(torch.randn(mol_vecs.size(0), 1, self.rand_size // 2))
        mol_eps = mol_eps.expand(-1, mol_vecs.size(1), -1)

        tree_vecs = torch.cat([tree_vecs, tree_eps], dim=-1)
        mol_vecs = torch.cat([mol_vecs, mol_eps], dim=-1)
        return self.B_t(tree_vecs), self.B_g(mol_vecs)
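
    # Translation-time sketch (illustration only, not part of the original
    # file): fuse_noise is the stochastic counterpart of fuse_pair. With no
    # target molecule y available, the difference vector is drawn from the
    # prior N(0, I) and fused into the encodings of x before decoding, e.g.:
    #
    #   x_tree_vecs, _, x_mol_vecs = model.encode(x_jtenc_holder, x_mpn_holder)
    #   x_tree_vecs, x_mol_vecs = model.fuse_noise(x_tree_vecs, x_mol_vecs)
    #   smiles = model.decode(x_tree_vecs, x_mol_vecs)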

    def fuse_pair(self, x_tree_vecs, x_mol_vecs, y_tree_vecs, y_mol_vecs, jtenc_scope, mpn_scope):
        diff_tree_vecs = y_tree_vecs.sum(dim=1) - x_tree_vecs.sum(dim=1)
        size = create_var(torch.Tensor([le for _, le in jtenc_scope]))
        diff_tree_vecs = diff_tree_vecs / size.unsqueeze(-1)

        diff_mol_vecs = y_mol_vecs.sum(dim=1) - x_mol_vecs.sum(dim=1)
        size = create_var(torch.Tensor([le for _, le in mpn_scope]))
        diff_mol_vecs = diff_mol_vecs / size.unsqueeze(-1)

        diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean, self.T_var)
        diff_mol_vecs, mol_kl = self.rsample(diff_mol_vecs, self.G_mean, self.G_var)

        diff_tree_vecs = diff_tree_vecs.unsqueeze(1).expand(-1, x_tree_vecs.size(1), -1)
        diff_mol_vecs = diff_mol_vecs.unsqueeze(1).expand(-1, x_mol_vecs.size(1), -1)
        x_tree_vecs = torch.cat([x_tree_vecs, diff_tree_vecs], dim=-1)
        x_mol_vecs = torch.cat([x_mol_vecs, diff_mol_vecs], dim=-1)

        return self.B_t(x_tree_vecs), self.B_g(x_mol_vecs), tree_kl + mol_kl

    def rsample(self, z_vecs, W_mean, W_var):
        z_mean = W_mean(z_vecs)
        z_log_var = -torch.abs(W_var(z_vecs))  # Following Mueller et al.
        kl_loss = -0.5 * torch.mean(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var))
        epsilon = create_var(torch.randn_like(z_mean))
        z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
        return z_vecs, kl_loss
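
    # The reparameterization trick used above, in isolation (a minimal sketch,
    # assuming nothing beyond torch): sampling z = mu + exp(log_var / 2) * eps
    # with eps ~ N(0, I) keeps z differentiable w.r.t. mu and log_var, and
    # kl_loss is the closed form of KL(N(mu, sigma^2) || N(0, I)), here
    # averaged over all elements.
    #
    #   mu, log_var = torch.zeros(4, 8), torch.zeros(4, 8)
    #   eps = torch.randn_like(mu)
    #   z = mu + torch.exp(log_var / 2) * eps
    #   kl = -0.5 * torch.mean(1.0 + log_var - mu * mu - torch.exp(log_var))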

    def forward(self, x_batch, y_batch, beta):
        x_batch, x_jtenc_holder, x_mpn_holder = x_batch
        y_batch, y_jtenc_holder, y_mpn_holder, y_jtmpn_holder = y_batch

        x_tree_vecs, _, x_mol_vecs = self.encode(x_jtenc_holder, x_mpn_holder)
        y_tree_vecs, y_tree_mess, y_mol_vecs = self.encode(y_jtenc_holder, y_mpn_holder)

        x_tree_vecs, x_mol_vecs, kl_div = self.fuse_pair(x_tree_vecs, x_mol_vecs, y_tree_vecs, y_mol_vecs, y_jtenc_holder[-1], y_mpn_holder[-1])

        word_loss, topo_loss, word_acc, topo_acc = self.decoder(y_batch, x_tree_vecs, x_mol_vecs)
        assm_loss, assm_acc = self.assm(y_batch, y_jtmpn_holder, x_mol_vecs, y_tree_mess)

        return word_loss + topo_loss + assm_loss + beta * kl_div, kl_div.item(), word_acc, topo_acc, assm_acc
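
    # Training sketch (hypothetical, not from the original file): forward
    # returns the total loss plus diagnostics, so a typical outer loop looks
    # like the following, where `loader` is assumed to yield tensorized
    # (x, y) molecule pairs and beta anneals the KL weight.
    #
    #   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    #   for x_batch, y_batch in loader:
    #       model.zero_grad()
    #       loss, kl_div, wacc, tacc, sacc = model(x_batch, y_batch, beta)
    #       loss.backward()
    #       optimizer.step()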

    def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, y_tree_mess):
        jtmpn_holder, batch_idx = jtmpn_holder
        fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
        batch_idx = create_var(batch_idx)

        cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess)

        x_mol_vecs = x_mol_vecs.sum(dim=1)  # average pooling?
        x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
        x_mol_vecs = self.A_assm(x_mol_vecs)  # bilinear
        scores = torch.bmm(
            x_mol_vecs.unsqueeze(1),
            cand_vecs.unsqueeze(-1)
        ).squeeze()

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for mol_tree in mol_batch:
            comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
            cnt += len(comp_nodes)
            for node in comp_nodes:
                label = node.cands.index(node.label)
                ncand = len(node.cands)
                cur_score = scores.narrow(0, tot, ncand)
                tot += ncand

                if cur_score[label].item() >= cur_score.max().item():
                    acc += 1

                label = create_var(torch.LongTensor([label]))
                all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

        all_loss = sum(all_loss) / len(mol_batch)
        return all_loss, acc * 1.0 / cnt

    def decode(self, x_tree_vecs, x_mol_vecs):
        # Batch decoding is currently not supported
        assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1

        pred_root, pred_nodes = self.decoder.decode(x_tree_vecs, x_mol_vecs)
        if len(pred_nodes) == 0:
            return None
        elif len(pred_nodes) == 1:
            return pred_root.smiles

        # Mark nid & is_leaf & atommap
        for i, node in enumerate(pred_nodes):
            node.nid = i + 1
            node.is_leaf = (len(node.neighbors) == 1)
            if len(node.neighbors) > 1:
                set_atommap(node.mol, node.nid)

        scope = [(0, len(pred_nodes))]
        jtenc_holder, mess_dict = JTNNEncoder.tensorize_nodes(pred_nodes, scope)
        _, tree_mess = self.jtnn(*jtenc_holder)
        tree_mess = (tree_mess, mess_dict)  # Important: tree_mess is a matrix, mess_dict is a python dict

        x_mol_vec_pooled = x_mol_vecs.sum(dim=1)  # average pooling?
        x_mol_vec_pooled = self.A_assm(x_mol_vec_pooled).squeeze()  # bilinear

        cur_mol = copy_edit_mol(pred_root.mol)
        global_amap = [{}] + [{} for _ in pred_nodes]
        global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}

        cur_mol = self.dfs_assemble(tree_mess, x_mol_vec_pooled, pred_nodes, cur_mol, global_amap, [], pred_root, None)
        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        return Chem.MolToSmiles(cur_mol) if cur_mol is not None else None

    def dfs_assemble(self, y_tree_mess, x_mol_vec_pooled, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node):
        fa_nid = fa_node.nid if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
        neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node.nid]
        cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None

        cand_smiles, cand_amap = zip(*cands)
        cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles]

        jtmpn_holder = JTMPN.tensorize(cands, y_tree_mess[1])
        fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
        cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess[0])

        scores = torch.mv(cand_vecs, x_mol_vec_pooled)
        _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        # for i in range(cand_idx.numel()):  # exhaustive alternative: try every candidate
        for i in range(min(cand_idx.numel(), 5)):  # only try the top 5 scoring candidates
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]

            cur_mol = attach_mols(cur_mol, children, [], new_global_amap)  # father is already attached
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None:
                continue

            result = True
            for nei_node in children:
                if nei_node.is_leaf:
                    continue
                cur_mol = self.dfs_assemble(y_tree_mess, x_mol_vec_pooled, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node)
                if cur_mol is None:
                    result = False
                    break

            if result:
                return cur_mol

        return None
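
# ---------------------------------------------------------------------------
# Usage sketch (hypothetical; none of the names below come from this file
# except DiffVAE and Vocab). The args fields mirror exactly what __init__
# reads; the vocab file path is illustrative.
# ---------------------------------------------------------------------------
#
#   from argparse import Namespace
#
#   args = Namespace(hidden_size=300, rand_size=16, depthT=6, depthG=3,
#                    share_embedding=False, use_molatt=True)
#   vocab = Vocab([line.strip() for line in open('vocab.txt')])
#   model = DiffVAE(vocab, args)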