-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
43 lines (36 loc) · 1.58 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from simpletransformers.classification import ClassificationModel
from rdkit import Chem
import argparse
import numpy as np
def seperatior_checker(rxn):
    """Check that ``rxn`` is a reaction SMILES of the form ``reactants>>product``.

    The text before the first ``>>`` is split on ``.`` into individual
    reactant SMILES; the text between the first and second ``>>`` (or the end
    of the string) is treated as the product.  The product is parsed first,
    then each reactant in order, using ``Chem.MolFromSmiles``.

    Args:
        rxn: Reaction SMILES string.

    Returns:
        ``True`` when the separator is present and every component parses.
        ``False`` when ``rxn`` is empty/``None`` or contains no ``>>``.

    Raises:
        ValueError: If the product or any reactant cannot be parsed.
    """
    # Without the separator there is nothing to validate; report that
    # explicitly so the caller's truthiness check cannot misfire.
    if not rxn or '>>' not in rxn:
        return False

    sides = rxn.split('>>')
    reactants, product = sides[0], sides[1]

    # The code relies on MolFromSmiles returning None for unparsable input.
    if Chem.MolFromSmiles(product) is None:
        raise ValueError(f'Invalid Smiles in product {product}')

    for smiles in reactants.split('.'):
        if Chem.MolFromSmiles(smiles) is None:
            raise ValueError(f'Invalid Smiles in reactants {smiles}')

    return True
def parse_args():
    """Parse the command-line options of the yield-prediction script.

    Returns:
        The ``argparse.Namespace`` produced from the process arguments, with
        ``reaction`` (``-s``/``--reaction``, defaults to ``None``) and
        ``name`` (``-n``/``--name``, defaults to ``'test_reaction'``).
    """
    cli = argparse.ArgumentParser(
        description='Run Buchwald Hartwig Yield prediction from command line',
    )

    # (flags, settings) pairs, registered in declaration order.
    options = (
        (
            ('-s', '--reaction'),
            {
                'default': None,
                'type': str,
                'help': 'Reaction input for yield predictions',
            },
        ),
        (
            ('-n', '--name'),
            {
                'default': 'test_reaction',
                'help': 'The name of the molecule',
            },
        ),
    )
    for flags, settings in options:
        cli.add_argument(*flags, **settings)

    return cli.parse_args()
def format_yield(pred):
    """Format a raw regression output as a percentage string, e.g. ``'87.5 %'``.

    Args:
        pred: The model prediction; a scalar or an array-like whose first
            element is the predicted yield as a fraction (presumably what
            ``model.predict`` returns for a single reaction -- confirm
            against the simpletransformers version in use).

    Returns:
        The prediction clipped to ``[0, 1]``, scaled to percent, rounded to
        two decimals and followed by ``' %'``.
    """
    # Flatten first so 0-d arrays, 1-element arrays and plain floats are all
    # handled; convert to a Python float so round() gets a plain float.
    fraction = float(np.clip(np.ravel(pred)[0], 0, 1))
    return f'{round(abs(fraction * 100), 2)} %'


def main():
    """Parse the CLI arguments, validate the reaction and print its predicted yield."""
    args = parse_args()

    if not args.reaction:
        print('Empty input')
        return

    # seperatior_checker raises ValueError for unparsable SMILES; a falsy
    # result is reported here as an invalid reaction.
    if not seperatior_checker(args.reaction):
        print('Invalid Reaction')
        return

    model = ClassificationModel(
        'roberta',
        'Parsa/Buchwald-Hartwig-Yield-prediction',
        use_cuda=False,
        num_labels=1,
        args={"regression": True},
    )
    pred, _ = model.predict([args.reaction])
    print(format_yield(pred))


if __name__ == '__main__':
    main()