Skip to content

Commit 3d8c55a

Browse files
committed
Include top-k recall metrics in training pipeline, and bump version
1 parent e89d946 commit 3d8c55a

File tree

3 files changed

+40
-24
lines changed

3 files changed

+40
-24
lines changed

project/utils/deepinteract_modules.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from project.utils.deepinteract_constants import FEATURE_INDICES, RESIDUE_COUNT_LIMIT, NODE_COUNT_LIMIT
2020
from project.utils.deepinteract_utils import construct_interact_tensor, glorot_orthogonal, get_geo_feats_from_edges, \
2121
construct_subsequenced_interact_tensors, insert_interact_tensor_logits, \
22-
remove_padding, remove_subsequenced_input_padding, calculate_top_k_prec, extract_object
22+
remove_padding, remove_subsequenced_input_padding, calculate_top_k_prec, calculate_top_k_recall, extract_object
2323
from project.utils.graph_utils import src_dot_dst, scaling, imp_exp_attn, out_edge_features, exp
2424
from project.utils.vision_modules import DeepLabV3Plus
2525

@@ -1926,17 +1926,17 @@ def validation_step(self, batch, batch_idx):
19261926

19271927
# Calculate top-k metrics
19281928
calculating_l_by_n_metrics = True
1929-
# Log only first 50 validation top-k precisions to limit algorithmic complexity due to sorting (if requested)
1929+
# Log only first 50 validation top-k metrics to limit algorithmic complexity due to sorting (if requested)
19301930
# calculating_l_by_n_metrics = batch_idx in [i for i in range(50)]
19311931
if calculating_l_by_n_metrics:
19321932
l = graph1.num_nodes() + graph2.num_nodes()
19331933
sorted_pred_indices = torch.argsort(preds[:, 1], descending=True)
19341934
top_10_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=10)
1935-
top_25_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=25)
1936-
top_50_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=50) if l > 50 else 0.0 # Catch short seq.
19371935
top_l_by_10_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=(l // 10))
19381936
top_l_by_5_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=(l // 5))
1939-
top_l_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=l)
1937+
top_l_recall = calculate_top_k_recall(sorted_pred_indices, labels, k=l)
1938+
top_l_by_2_recall = calculate_top_k_recall(sorted_pred_indices, labels, k=(l // 2))
1939+
top_l_by_5_recall = calculate_top_k_recall(sorted_pred_indices, labels, k=(l // 5))
19401940

19411941
# Calculate the protein interface prediction (PICP) loss along with additional PIP metrics
19421942
loss = self.loss_fn(sampled_logits, labels) # Calculate loss of a single complex
@@ -1951,11 +1951,11 @@ def validation_step(self, batch, batch_idx):
19511951
self.log(f'val_ce', loss, sync_dist=True)
19521952
if calculating_l_by_n_metrics:
19531953
self.log('val_top_10_prec', top_10_prec, sync_dist=True)
1954-
self.log('val_top_25_prec', top_25_prec, sync_dist=True)
1955-
self.log('val_top_50_prec', top_50_prec, sync_dist=True)
19561954
self.log('val_top_l_by_10_prec', top_l_by_10_prec, sync_dist=True)
19571955
self.log('val_top_l_by_5_prec', top_l_by_5_prec, sync_dist=True)
1958-
self.log('val_top_l_prec', top_l_prec, sync_dist=True)
1956+
self.log('val_top_l_recall', top_l_recall, sync_dist=True)
1957+
self.log('val_top_l_by_2_recall', top_l_by_2_recall, sync_dist=True)
1958+
self.log('val_top_l_by_5_recall', top_l_by_5_recall, sync_dist=True)
19591959

19601960
return {
19611961
'loss': loss,
@@ -2033,11 +2033,11 @@ def test_step(self, batch, batch_idx):
20332033
l = min(graph1.num_nodes(), graph2.num_nodes()) # Use the smallest length of the two chains as our denominator
20342034
sorted_pred_indices = torch.argsort(preds[:, 1], descending=True)
20352035
top_10_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=10)
2036-
top_25_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=25)
2037-
top_50_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=50) if l > 50 else 0.0 # Catch short seq.
20382036
top_l_by_10_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=(l // 10))
20392037
top_l_by_5_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=(l // 5))
2040-
top_l_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=l)
2038+
top_l_recall = calculate_top_k_recall(sorted_pred_indices, labels, k=l)
2039+
top_l_by_2_recall = calculate_top_k_recall(sorted_pred_indices, labels, k=(l // 2))
2040+
top_l_by_5_recall = calculate_top_k_recall(sorted_pred_indices, labels, k=(l // 5))
20412041

20422042
# Calculate the protein interface prediction (PICP) loss along with additional PIP metrics
20432043
loss = self.loss_fn(sampled_logits, labels) # Calculate loss of a single complex
@@ -2062,11 +2062,11 @@ def test_step(self, batch, batch_idx):
20622062
# Log test step metric(s)
20632063
self.log(f'test_ce', loss, sync_dist=True)
20642064
self.log('test_top_10_prec', top_10_prec, sync_dist=True)
2065-
self.log('test_top_25_prec', top_25_prec, sync_dist=True)
2066-
self.log('test_top_50_prec', top_50_prec, sync_dist=True)
20672065
self.log('test_top_l_by_10_prec', top_l_by_10_prec, sync_dist=True)
20682066
self.log('test_top_l_by_5_prec', top_l_by_5_prec, sync_dist=True)
2069-
self.log('test_top_l_prec', top_l_prec, sync_dist=True)
2067+
self.log('test_top_l_recall', top_l_recall, sync_dist=True)
2068+
self.log('test_top_l_by_2_recall', top_l_by_2_recall, sync_dist=True)
2069+
self.log('test_top_l_by_5_recall', top_l_by_5_recall, sync_dist=True)
20702070

20712071
return {
20722072
'loss': loss,
@@ -2082,6 +2082,9 @@ def test_step(self, batch, batch_idx):
20822082
'top_10_prec': top_10_prec,
20832083
'top_l_by_10_prec': top_l_by_10_prec,
20842084
'top_l_by_5_prec': top_l_by_5_prec,
2085+
'top_l_recall': top_l_recall,
2086+
'top_l_by_2_recall': top_l_by_2_recall,
2087+
'top_l_by_5_recall': top_l_by_5_recall,
20852088
'target': filepaths[0].split(os.sep)[-1][:4]
20862089
}
20872090

@@ -2112,17 +2115,20 @@ def test_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT):
21122115
test_preds_rounded = [(output_dict['test_preds_rounded']) for output_dict in outputs]
21132116
test_labels = [output_dict['test_labels'] for output_dict in outputs]
21142117

2115-
# Write out test top-k precision results to CSV
2116-
prec_data = {
2118+
# Write out test top-k metric results to CSV
2119+
metrics_data = {
21172120
'top_10_prec': [extract_object(output_dict['top_10_prec']) for output_dict in outputs],
21182121
'top_l_by_10_prec': [extract_object(output_dict['top_l_by_10_prec']) for output_dict in outputs],
21192122
'top_l_by_5_prec': [extract_object(output_dict['top_l_by_5_prec']) for output_dict in outputs],
2123+
'top_l_recall': [extract_object(output_dict['top_l_recall']) for output_dict in outputs],
2124+
'top_l_by_2_recall': [extract_object(output_dict['top_l_by_2_recall']) for output_dict in outputs],
2125+
'top_l_by_5_recall': [extract_object(output_dict['top_l_by_5_recall']) for output_dict in outputs],
21202126
'target': [extract_object(output_dict['target']) for output_dict in outputs],
21212127
}
2122-
prec_df = pd.DataFrame(data=prec_data)
2123-
prec_df_name_prefix = 'casp_capri' if self.testing_with_casp_capri else 'dips_plus_test'
2124-
prec_df_name = prec_df_name_prefix + '_top_prec.csv'
2125-
prec_df.to_csv(prec_df_name)
2128+
metrics_df = pd.DataFrame(data=metrics_data)
2129+
metrics_df_name_prefix = 'casp_capri' if self.testing_with_casp_capri else 'dips_plus_test'
2130+
metrics_df_name = metrics_df_name_prefix + '_top_metrics.csv'
2131+
metrics_df.to_csv(metrics_df_name)
21262132

21272133
if not self.testing_with_casp_capri: # Testing with DIPS-Plus
21282134
# Filter out all but the first 55 test predictions and labels to reduce storage requirements

project/utils/deepinteract_utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,11 @@
2929
from Bio.Seq import Seq
3030
from Bio.SeqRecord import SeqRecord
3131
from biopandas.pdb import PandasPdb
32-
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
33-
3432
from project.utils.deepinteract_constants import FEAT_COLS, ALLOWABLE_FEATS, D3TO1
3533
from project.utils.dips_plus_utils import postprocess_pruned_pairs, impute_postprocessed_missing_feature_values
3634
from project.utils.graph_utils import prot_df_to_dgl_graph_feats
3735
from project.utils.protein_feature_utils import GeometricProteinFeatures
36+
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
3837

3938
try:
4039
from types import SliceType
@@ -328,7 +327,7 @@ def substitute_missing_atoms(struct_df: pd.DataFrame, all_atom_struct_df: pd.Dat
328327
raise NotImplementedError('Error: A missing atom was found, and it is not possible to process it.')
329328

330329
# Choose a replacement for the missing atom
331-
available_atom_keys = set(atom_names) - {missing_atom_key}
330+
available_atom_keys = set(atom_names) - {missing_atom_key, 'CA'} # Disallow CA atoms from being a sub
332331
replacement_atom_name = available_atom_keys.pop() # Choose the first available atom as the substitute
333332
replacement_atom = ca_atom_support_atoms[ca_atom_support_atoms['atom_name'] == replacement_atom_name]
334333
logging.info(f'Found a missing {missing_atom_key} atom for row number {ca_atom_idx} -'
@@ -973,6 +972,17 @@ def calculate_top_k_prec(sorted_pred_indices: torch.Tensor, labels: torch.Tensor
973972
return prec
974973

975974

975+
def calculate_top_k_recall(sorted_pred_indices: torch.Tensor, labels: torch.Tensor, k: int):
    """Calculate the top-k interaction recall.

    Recall is the sum of the labels found among the ``k`` highest-ranked
    predictions divided by the sum of all labels.

    Args:
        sorted_pred_indices: Indices into ``labels``, ordered so that the
            predictions to score first come first (callers pass the result of
            a descending ``torch.argsort`` over the positive-class scores).
        labels: Ground-truth labels indexed by ``sorted_pred_indices``.
            NOTE(review): assumed to be binary (0/1) so that summing counts
            positives - confirm against the data pipeline.
        k: Number of top-ranked predictions to score. Values larger than the
            number of predictions simply score every prediction.

    Returns:
        The recall as a Python float. When ``labels`` contains no positives,
        recall is undefined and ``0.0`` is returned instead of raising
        ``ZeroDivisionError``, so that metric logging does not crash on
        complexes without any positive labels.
    """
    num_pos_labels = torch.sum(labels).item()
    if num_pos_labels == 0:
        # Nothing to recall; avoid dividing by zero
        return 0.0

    selected_pred_indices = sorted_pred_indices[:k]
    num_correct = torch.sum(labels[selected_pred_indices]).item()
    return num_correct / num_pos_labels
984+
985+
976986
def extract_object(obj: object):
    """Convert a ``torch.Tensor`` to a NumPy array; return any other object unchanged.

    Args:
        obj: Either a ``torch.Tensor`` (including subclasses such as
            ``torch.nn.Parameter``) or an arbitrary non-tensor object such as
            a Python scalar or string.

    Returns:
        A NumPy array holding the tensor's values if ``obj`` is a tensor,
        otherwise ``obj`` itself.
    """
    if isinstance(obj, torch.Tensor):
        # Detach before converting: Tensor.numpy() raises on tensors that require grad
        return obj.detach().cpu().numpy()
    return obj

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
setup(
66
name='DeepInteract',
7-
version='1.0.7',
7+
version='1.0.8',
88
description='A geometric deep learning pipeline for predicting protein interface contacts.',
99
author='Alex Morehead',
1010
author_email='[email protected]',

0 commit comments

Comments
 (0)