-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutils.py
94 lines (82 loc) · 3.03 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from rdkit.Chem import Draw
import numpy as np
from rdkit.Chem import AllChem
from rdkit import Chem
import torch
import os
import json
import pdb
def mol2array(mol):
    """Draw ``mol`` with RDKit and return the picture as a NumPy array.

    The molecule is rendered through ``Draw.MolToImage`` with kekulization
    disabled. Only the first three channels of the rendered image are kept
    (presumably RGB, dropping any alpha channel).
    """
    rendered = Draw.MolToImage(mol, kekulize=False)
    pixels = np.array(rendered)
    # Keep channels 0..2 only.
    return pixels[:, :, :3]
def check(smile):
    """Tell whether the longest fragment of a SMILES string sanitizes cleanly.

    The string is split on ``'.'`` and the longest piece (by character
    count; on ties the one appearing last) is parsed without sanitization
    and then run through RDKit's ``SANITIZE_PROPERTIES`` step only.

    Returns:
        ``True`` when parsing and the property sanitization raise nothing,
        ``False`` when either of them raises any ``Exception``.
    """
    # sorted() is stable, so among equally long fragments the last one wins,
    # exactly as with an in-place sort followed by [-1].
    fragments = sorted(smile.split('.'), key=len)
    longest = fragments[-1]
    try:
        candidate = Chem.MolFromSmiles(longest, sanitize=False)
        Chem.SanitizeMol(candidate, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except Exception:
        return False
    return True
def mol2file(m, name):
    """Write a 2D depiction of ``m`` to ``./img/<name>``.

    Args:
        m: RDKit molecule to draw. ``AllChem.Compute2DCoords`` is called on
            it before drawing.
        name: File name of the image, joined onto the ``./img`` directory.
    """
    target = os.path.join('./img', name)
    # Make sure the destination directory exists so the write does not fail
    # when ./img (or a sub-directory contained in ``name``) is missing.
    os.makedirs(os.path.dirname(target), exist_ok=True)
    AllChem.Compute2DCoords(m)
    Draw.MolToFile(m, target)
def result2mol(args):  # for threading
    """Assemble an RDKit molecule from per-atom arrays.

    All inputs come packed in a single tuple so the function can be used
    directly as the callable of a pool ``map``.

    Args:
        args: ``(element, mask, bond, aroma, charge, reactant)``.

            * ``element`` -- tensor whose entries are passed to ``Chem.Atom``;
              its first dimension gives the number of atom slots.
            * ``mask`` -- tensor; slots whose mask value equals 1 are skipped.
            * ``bond`` -- tensor whose row ``i`` lists the slot indices atom
              ``i`` is bonded to. An index repeated two/three times in a row
              encodes a double/triple bond.
            * ``aroma`` -- indexable per-slot values; a single bond between
              two atoms sharing the same positive value becomes aromatic.
            * ``charge`` -- tensor of formal charges; non-zero entries are set
              on the corresponding atoms.
            * ``reactant`` -- ``None`` or an indexable of flags; atoms with a
              truthy flag receive atom-map number ``slot index + 1``.

            The original shape note reads ``[L], [L], [L, 4], [l], [l]``
            (presumably for element, mask, bond, aroma, charge).

    Returns:
        ``(mol, smile, ok)``: the built molecule, its SMILES string and the
        result of ``check(smile)``.
    """
    element, mask, bond, aroma, charge, reactant = args
    # [L], [L], [L, 4], [l], [l]

    # Convert everything that is indexed element-wise to plain Python lists
    # once, instead of doing tensor comparisons inside the loops.
    keep = mask.ne(1).cpu().numpy().tolist()
    num_slots = element.shape[0]
    element = element.cpu().numpy().tolist()
    charge = charge.cpu().numpy().tolist()
    bond = bond.cpu().numpy().tolist()

    mol = Chem.RWMol()

    # Add atoms to mol and remember which slot ended up at which atom index.
    node_to_idx = {}
    for i in range(num_slots):
        if not keep[i]:
            continue
        atom = Chem.Atom(element[i])
        if reactant is not None and reactant[i]:
            atom.SetAtomMapNum(i + 1)
        node_to_idx[i] = mol.AddAtom(atom)

    # Add bonds between adjacent atoms.
    bond_types = Chem.rdchem.BondType
    for this in range(num_slots):
        if not keep[this]:
            continue
        neighbours = bond[this]
        for j, other in enumerate(neighbours):
            # Only traverse half the matrix (other < this), handle each
            # distinct neighbour once, and require the bond to be listed
            # from both ends.
            if other >= this or other in neighbours[:j] or this not in bond[other]:
                continue
            # How often each end lists the other one; either side may carry
            # the bond order.
            multiplicities = (neighbours.count(other), bond[other].count(this))
            if 3 in multiplicities:
                bond_type = bond_types.TRIPLE
            elif 2 in multiplicities:
                bond_type = bond_types.DOUBLE
            elif aroma[this] == aroma[other] and aroma[this] > 0:
                bond_type = bond_types.AROMATIC
            else:
                bond_type = bond_types.SINGLE
            # NOTE(review): raises KeyError if ``other`` refers to a masked
            # slot -- presumably the inputs never do that; confirm.
            mol.AddBond(node_to_idx[this], node_to_idx[other], bond_type)

    # Apply non-zero formal charges to the atoms that were kept.
    for i, formal_charge in enumerate(charge):
        if not keep[i]:
            continue
        if formal_charge != 0:
            mol.GetAtomWithIdx(node_to_idx[i]).SetFormalCharge(formal_charge)

    # Convert RWMol to Mol object
    mol = mol.GetMol()
    Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ADJUSTHS)
    smile = Chem.MolToSmiles(mol)
    return mol, smile, check(smile)
def visualize(element, mask, bond, aroma, charge, reactant=None):
    """Build a molecule from the given arrays and render it.

    The arguments are packed into a tuple and forwarded to ``result2mol``;
    the resulting molecule is drawn with ``mol2array``.

    Returns:
        ``(array, smile)``: the image array from ``mol2array`` and the SMILES
        string from ``result2mol``. The validity flag is discarded.
    """
    packed = (element, mask, bond, aroma, charge, reactant)
    molecule, smiles, _valid = result2mol(packed)
    return mol2array(molecule), smiles