-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathinference.py
More file actions
119 lines (99 loc) · 4.09 KB
/
inference.py
File metadata and controls
119 lines (99 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import re
import torch
import argparse
from Bio import PDB
from omegaconf import OmegaConf
from src.metoken_model import MeToken_Model
from src.datasets.featurizer import featurize
from src.constant import PTMtype_list
def parse_arguments():
    """Build the command-line interface for PTM inference and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``pdb_file_path`` (required), ``checkpoint_path``,
        ``output_json_path``, ``output_hdf5_path`` and ``query_indices`` (a list of
        1-based residue positions).
    """
    cli = argparse.ArgumentParser(description="PTM prediction from PDB files.")

    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = [
        ('--pdb_file_path', dict(
            type=str, required=True,
            help="Path to the input PDB file.",
        )),
        ('--checkpoint_path', dict(
            type=str, default='pretrained_model/checkpoint.ckpt',
            help="Path to the model checkpoint file.",
        )),
        ('--output_json_path', dict(
            type=str, default='output/predict.json',
            help="Path to save the prediction results in JSON format. Default is 'output/predict.json'.",
        )),
        ('--output_hdf5_path', dict(
            type=str, default='output/predict.hdf5',
            help="Path to save the prediction results in HDF5 format. Default is 'output/predict.hdf5'.",
        )),
        ('--query_indices', dict(
            type=int, nargs='+', default=[31],
            help="List of residue indices for PTM prediction (Start from 1).",
        )),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
def extract_pdb_id(file_path):
    """Extract the PDB ID (file-name stem) from a path to a ``.pdb`` file.

    Only the final path component is inspected, with both ``/`` and ``\\``
    treated as separators. Directory names that contain ``.pdb`` (e.g.
    ``structs.pdb/1abc.pdb``) and Windows-style paths therefore do not leak
    into the ID.

    Args:
        file_path: Path to a PDB file, e.g. ``data/1abc.pdb``.

    Returns:
        The file name up to its last ``.pdb`` occurrence, e.g. ``1abc``.

    Raises:
        ValueError: If the file name has no ``.pdb`` preceded by at least one
            character.
    """
    file_name = re.split(r'[\\/]', file_path)[-1]
    # Greedy ``.+`` with a zero-width lookahead keeps everything before the
    # last ".pdb" in the file name, mirroring the original pattern's behaviour
    # on a single path component.
    match = re.search(r'(.+)(?=\.pdb)', file_name)
    if match is None:
        raise ValueError(f"Could not extract PDB ID from file path: {file_path}")
    return match.group(1)
def get_seq_str(pdb_file_path, chain_id='A'):
    """Read one chain of a PDB file into the input dict used for featurization.

    Args:
        pdb_file_path: Path to a file whose name ends in ``.pdb``.
        chain_id: Identifier of the chain to extract (default ``'A'``).

    Returns:
        dict with keys
          - ``"id"``: structure ID derived from the file name,
          - ``"seq"``: one-letter amino-acid sequence of the chain,
          - ``"coords_chain_A"``: dict mapping ``'N_chain_A'``, ``'C_chain_A'``,
            ``'CA_chain_A'`` and ``'O_chain_A'`` to per-residue ``[x, y, z]``
            lists that are index-aligned with ``"seq"``. The key names always
            say ``chain_A``, whatever ``chain_id`` is.

    Raises:
        ValueError: If the path does not end in ``.pdb``, the file has no
            model or lacks the requested chain, or the chain contains no
            amino-acid residues.
    """
    # Local import: ``Bio.PDB.Polypeptide.three_to_one`` was removed in
    # Biopython 1.80, while ``seq1`` exists in both old and new releases.
    from Bio.SeqUtils import seq1

    if not pdb_file_path.endswith('.pdb'):
        # The original ``raise "<str>"`` is itself a TypeError in Python 3;
        # ValueError lets main() report this like other input errors.
        raise ValueError("The PDB file path is invalid.")

    pdb_id = extract_pdb_id(pdb_file_path)
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_id, pdb_file_path)

    # Use only the first model. Multi-model files (e.g. NMR ensembles) repeat
    # every chain once per model, and walking all of them would concatenate
    # the sequence and coordinates several times over.
    model = next(iter(structure), None)
    if model is None or chain_id not in model:
        raise ValueError(f"Chain {chain_id!r} not found in {pdb_file_path}.")

    backbone_atoms = ('N', 'C', 'CA', 'O')
    coords_chain_A = {f'{atom}_chain_A': [] for atom in backbone_atoms}
    residues = []
    for residue in model[chain_id]:
        if not PDB.is_aa(residue):
            continue
        # seq1 converts a three-letter code to one letter; codes it does not
        # recognise become 'X' instead of raising.
        residues.append(seq1(residue.resname))
        for atom in backbone_atoms:
            if atom in residue:
                xyz = residue[atom].coord.tolist()
            else:
                # Pad a missing backbone atom with NaN so every coordinate
                # list stays index-aligned with the sequence.
                # NOTE(review): assumes featurize() masks non-finite
                # coordinates — confirm.
                xyz = [float('nan')] * 3
            coords_chain_A[f'{atom}_chain_A'].append(xyz)

    if not residues:
        raise ValueError(
            f"Chain {chain_id!r} in {pdb_file_path} contains no amino-acid residues."
        )

    return {
        "id": pdb_id,
        "seq": ''.join(residues),
        "coords_chain_A": coords_chain_A,
    }
def apply_ptm_indices(protein_data, ptm_indices):
    """Attach a per-residue 0/1 query mask to ``protein_data`` under ``"ptm"``.

    Args:
        protein_data: dict holding at least a ``"seq"`` string; it is modified
            in place.
        ptm_indices: 0-based residue positions to mark. If ``-1`` is present,
            every residue is marked. Otherwise positions outside
            ``[0, len(seq))`` are ignored.

    Returns:
        The same ``protein_data`` dict, whose ``"ptm"`` entry is now a list of
        0/1 ints with one entry per residue.
    """
    n_residues = len(protein_data["seq"])
    if -1 in ptm_indices:
        marked = set(range(n_residues))
    else:
        marked = {i for i in ptm_indices if 0 <= i < n_residues}
    protein_data["ptm"] = [int(pos in marked) for pos in range(n_residues)]
    return protein_data
def main():
    """Predict and print PTM types for selected residues of one PDB structure.

    The pipeline:
      1. Parse the CLI arguments.
      2. Build the model input from the PDB file via ``get_seq_str`` (default
         chain ``'A'``).
      3. Mark the requested residues.
      4. Load the checkpoint and ``configs/MeToken.yaml``; this path is
         resolved relative to the current working directory.
      5. Print the predicted PTM type for each requested 1-based position.

    A missing file or invalid input is reported on stderr, and the process
    exits with status 1.

    NOTE: ``--output_json_path`` and ``--output_hdf5_path`` are parsed but not
    used here; results are only printed.
    """
    import sys

    args = parse_arguments()
    try:
        protein_data = get_seq_str(args.pdb_file_path)
        seq_length = len(protein_data["seq"])

        # Validate up front: out-of-range positions would otherwise be skipped
        # silently when marking queries, and then either crash on preds[pos]
        # or (for 0 / negative values) wrap around and report the wrong residue.
        invalid = [i for i in args.query_indices if not 1 <= i <= seq_length]
        if invalid:
            raise ValueError(
                f"Query indices {invalid} are outside the valid range "
                f"1..{seq_length} for {args.pdb_file_path}."
            )
        args.query_indices = [index - 1 for index in args.query_indices]
        protein_data = apply_ptm_indices(protein_data, args.query_indices)

        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
        # machines; load_state_dict copies the tensors into the model anyway.
        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
        params = OmegaConf.load('configs/MeToken.yaml')
        model = MeToken_Model(params)
        model.load_state_dict(checkpoint)
        # eval() switches off training-only behaviour (e.g. dropout) so the
        # prediction is deterministic; no_grad() skips building autograd graphs.
        model.eval()

        data = featurize([protein_data])
        with torch.no_grad():
            result = model(data)
        preds = result['log_probs'].argmax(dim=-1).cpu()
        for pos in args.query_indices:
            print(f'PTM type at the position {pos+1} is {PTMtype_list[int(preds[pos])]}.\n')
    except (FileNotFoundError, ValueError) as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()