-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
98 lines (87 loc) · 3.43 KB
/
predict.py
File metadata and controls
98 lines (87 loc) · 3.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
author: majiahao
date: 2025/11/20
for prediction
"""
# Load the model from the specified path
import torch
from torch import nn
from config.config import get_args
from utils.utils import get_dict
from dataset.dataset import CSGODataset
from model.model import CSGOpredictor
from torch.utils.data import DataLoader
from tqdm import tqdm
def hot_vec_to_label(hot_vec, inv_go_dict):
    """Convert one multi-hot prediction vector into its GO terms.

    Args:
        hot_vec: Multi-hot vector for a single sample. Either a flat vector
            of length ``num_go_terms`` or a single ``(1, num_go_terms)`` row;
            anything ``torch.as_tensor`` accepts (tensor, list, ...) works.
        inv_go_dict: Mapping from label index (int) to GO term.

    Returns:
        The GO terms whose entry in ``hot_vec`` equals 1, in ascending index
        order. The list is empty when no entry is set.

    Raises:
        ValueError: If ``hot_vec`` contains more than one sample.
        KeyError: If a set index has no entry in ``inv_go_dict``.
    """
    vec = torch.as_tensor(hot_vec)
    # A (1, num_go_terms) row is fine, but flattening a real batch would
    # silently produce indices that no longer correspond to GO labels.
    if vec.dim() > 1 and vec.numel() != vec.shape[-1]:
        raise ValueError(
            f"hot_vec must describe a single sample, got shape {tuple(vec.shape)}"
        )
    vec = vec.reshape(-1)
    # Vectorised lookup of the positions that are set to 1.
    predicted_indices = torch.nonzero(vec == 1).flatten().tolist()
    return [inv_go_dict[idx] for idx in predicted_indices]
def predict(model, dataloader, sl_dict, dom_dict, go_dict, device, model_path,
            threshold=0.9):
    """Load a checkpoint into ``model`` and predict GO terms for a dataloader.

    Args:
        model: Network to run. It is called as in ``main()``: the third
            embedding entry positionally, plus the ``location_idx``,
            ``location_pad``, ``domain_idx`` and ``domain_pad`` keywords.
        dataloader: Iterable of batches. Each batch is indexed with the keys
            ``"uniprot_id"``, ``"embeddings"``, ``"subcellular_labels"``,
            ``"subcellular_pad"``, ``"domain_labels"`` and ``"domain_pad"``.
        sl_dict: Unused here; kept so existing call sites stay valid.
        dom_dict: Unused here; kept so existing call sites stay valid.
        go_dict: Mapping from GO term to label index.
        device: Device the model and the batch tensors are moved to.
        model_path: Path of the saved ``state_dict``.
        threshold: A GO term is reported when its sigmoid probability is
            strictly greater than this value.

    Returns:
        A list of ``(uniprot_id, go_term, probability)`` tuples, one per
        predicted term, in dataloader order.
    """
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only host instead of failing during deserialisation.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    inv_go_dict = {idx: term for term, idx in go_dict.items()}
    predictions = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Predicting"):
            uniprot_ids = batch["uniprot_id"]
            esm_feats = [feat.to(device) for feat in batch["embeddings"]]
            sl_idx = batch["subcellular_labels"].to(device)
            sl_pad = batch["subcellular_pad"].to(device)
            dom_idx = batch["domain_labels"].to(device)
            dom_pad = batch["domain_pad"].to(device)

            # NOTE(review): index 2 mirrors main(); confirm which embedding
            # entry the model was trained on.
            logits = model(
                esm_feats[2],
                location_idx=sl_idx,
                location_pad=sl_pad,
                domain_idx=dom_idx,
                domain_pad=dom_pad,
            )  # presumably [B, go_num], per the comment in main()
            probs = torch.sigmoid(logits).cpu()

            for protein_id, row in zip(uniprot_ids, probs):
                for idx in torch.nonzero(row > threshold).flatten().tolist():
                    predictions.append(
                        (protein_id, inv_go_dict[idx], row[idx].item())
                    )
    return predictions
def main():
    """Run the predictor over the prediction set and append hits to a TSV.

    Reads the configuration from ``get_args()``, builds the model, dataset
    and dataloader, and appends one ``uniprot_id<TAB>GO term<TAB>probability``
    line to ``args.predict_output_path`` for every GO term whose sigmoid
    probability is above 0.9.
    """
    args = get_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CSGOpredictor(
        args.hidden_size,
        args.subcellular_labels,
        args.domain_labels,
        args.go_labels,
        args.fusion_stage,
        args.fusion_method,
    )
    # NOTE(review): no checkpoint is loaded in this function, so the model
    # runs with whatever weights CSGOpredictor initialises. predict() loads
    # a state_dict but is never called from here -- confirm the intended flow.
    # The inputs below are moved to `device`, so the model has to live there
    # too; eval() switches off training-only layer behaviour for inference.
    model.to(device)
    model.eval()

    # NOTE(review): sl_dict, dom_dict and go_dict are not defined in the
    # visible file (get_dict is imported but unused); they must be built
    # before this point -- confirm against utils.utils.get_dict.
    dataset = CSGODataset(
        args.predict_data_path, args.cache_path, sl_dict, dom_dict, go_dict
    )
    # Index -> GO term lookup needed by hot_vec_to_label.
    inv_go_dict = {idx: term for term, idx in go_dict.items()}

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )
    pbar = tqdm(dataloader, desc="Predicting")

    # Open the output once (append mode, as before) instead of reopening it
    # for every sample.
    with open(args.predict_output_path, "a") as f, torch.no_grad():
        for batch in pbar:
            uniprot_ids = batch["uniprot_id"]
            esm_feats = [feat.to(device) for feat in batch["embeddings"]]
            sl_idx = batch["subcellular_labels"].to(device)
            sl_pad = batch["subcellular_pad"].to(device)
            dom_idx = batch["domain_labels"].to(device)
            dom_pad = batch["domain_pad"].to(device)

            # Forward pass.
            logits = model(
                esm_feats[2],
                location_idx=sl_idx,
                location_pad=sl_pad,
                domain_idx=dom_idx,
                domain_pad=dom_pad,
            )  # [B, go_num]

            # Sigmoid turns the logits into per-term probabilities.
            probs = torch.sigmoid(logits)
            # Multi-hot vector of the terms above the 0.9 threshold.
            hot_vecs = (probs > 0.9).float()

            for hot_vec, protein_id, prob in zip(hot_vecs, uniprot_ids, probs):
                for term in hot_vec_to_label(hot_vec, inv_go_dict):
                    f.write(f"{protein_id}\t{term}\t{prob[go_dict[term]]:.4f}\n")


if __name__ == "__main__":
    main()