19
19
from project .utils .deepinteract_constants import FEATURE_INDICES , RESIDUE_COUNT_LIMIT , NODE_COUNT_LIMIT
20
20
from project .utils .deepinteract_utils import construct_interact_tensor , glorot_orthogonal , get_geo_feats_from_edges , \
21
21
construct_subsequenced_interact_tensors , insert_interact_tensor_logits , \
22
- remove_padding , remove_subsequenced_input_padding , calculate_top_k_prec , extract_object
22
+ remove_padding , remove_subsequenced_input_padding , calculate_top_k_prec , calculate_top_k_recall , extract_object
23
23
from project .utils .graph_utils import src_dot_dst , scaling , imp_exp_attn , out_edge_features , exp
24
24
from project .utils .vision_modules import DeepLabV3Plus
25
25
@@ -1926,17 +1926,17 @@ def validation_step(self, batch, batch_idx):
1926
1926
1927
1927
# Calculate top-k metrics
1928
1928
calculating_l_by_n_metrics = True
1929
- # Log only first 50 validation top-k precisions to limit algorithmic complexity due to sorting (if requested)
1929
+ # Optionally restrict top-k metric logging to the first 50 validation batches to limit the cost of sorting (currently disabled; see the commented-out line below)
1930
1930
# calculating_l_by_n_metrics = batch_idx in [i for i in range(50)]
1931
1931
if calculating_l_by_n_metrics :
1932
1932
l = graph1 .num_nodes () + graph2 .num_nodes ()
1933
1933
sorted_pred_indices = torch .argsort (preds [:, 1 ], descending = True )
1934
1934
top_10_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = 10 )
1935
- top_25_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = 25 )
1936
- top_50_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = 50 ) if l > 50 else 0.0 # Catch short seq.
1937
1935
top_l_by_10_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = (l // 10 ))
1938
1936
top_l_by_5_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = (l // 5 ))
1939
- top_l_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = l )
1937
+ top_l_recall = calculate_top_k_recall (sorted_pred_indices , labels , k = l )
1938
+ top_l_by_2_recall = calculate_top_k_recall (sorted_pred_indices , labels , k = (l // 2 ))
1939
+ top_l_by_5_recall = calculate_top_k_recall (sorted_pred_indices , labels , k = (l // 5 ))
1940
1940
1941
1941
# Calculate the protein interface prediction (PICP) loss along with additional PIP metrics
1942
1942
loss = self .loss_fn (sampled_logits , labels ) # Calculate loss of a single complex
@@ -1951,11 +1951,11 @@ def validation_step(self, batch, batch_idx):
1951
1951
self .log (f'val_ce' , loss , sync_dist = True )
1952
1952
if calculating_l_by_n_metrics :
1953
1953
self .log ('val_top_10_prec' , top_10_prec , sync_dist = True )
1954
- self .log ('val_top_25_prec' , top_25_prec , sync_dist = True )
1955
- self .log ('val_top_50_prec' , top_50_prec , sync_dist = True )
1956
1954
self .log ('val_top_l_by_10_prec' , top_l_by_10_prec , sync_dist = True )
1957
1955
self .log ('val_top_l_by_5_prec' , top_l_by_5_prec , sync_dist = True )
1958
- self .log ('val_top_l_prec' , top_l_prec , sync_dist = True )
1956
+ self .log ('val_top_l_recall' , top_l_recall , sync_dist = True )
1957
+ self .log ('val_top_l_by_2_recall' , top_l_by_2_recall , sync_dist = True )
1958
+ self .log ('val_top_l_by_5_recall' , top_l_by_5_recall , sync_dist = True )
1959
1959
1960
1960
return {
1961
1961
'loss' : loss ,
@@ -2033,11 +2033,11 @@ def test_step(self, batch, batch_idx):
2033
2033
l = min (graph1 .num_nodes (), graph2 .num_nodes ()) # Use the smallest length of the two chains as our denominator
2034
2034
sorted_pred_indices = torch .argsort (preds [:, 1 ], descending = True )
2035
2035
top_10_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = 10 )
2036
- top_25_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = 25 )
2037
- top_50_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = 50 ) if l > 50 else 0.0 # Catch short seq.
2038
2036
top_l_by_10_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = (l // 10 ))
2039
2037
top_l_by_5_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = (l // 5 ))
2040
- top_l_prec = calculate_top_k_prec (sorted_pred_indices , labels , k = l )
2038
+ top_l_recall = calculate_top_k_recall (sorted_pred_indices , labels , k = l )
2039
+ top_l_by_2_recall = calculate_top_k_recall (sorted_pred_indices , labels , k = (l // 2 ))
2040
+ top_l_by_5_recall = calculate_top_k_recall (sorted_pred_indices , labels , k = (l // 5 ))
2041
2041
2042
2042
# Calculate the protein interface prediction (PICP) loss along with additional PIP metrics
2043
2043
loss = self .loss_fn (sampled_logits , labels ) # Calculate loss of a single complex
@@ -2062,11 +2062,11 @@ def test_step(self, batch, batch_idx):
2062
2062
# Log test step metric(s)
2063
2063
self .log (f'test_ce' , loss , sync_dist = True )
2064
2064
self .log ('test_top_10_prec' , top_10_prec , sync_dist = True )
2065
- self .log ('test_top_25_prec' , top_25_prec , sync_dist = True )
2066
- self .log ('test_top_50_prec' , top_50_prec , sync_dist = True )
2067
2065
self .log ('test_top_l_by_10_prec' , top_l_by_10_prec , sync_dist = True )
2068
2066
self .log ('test_top_l_by_5_prec' , top_l_by_5_prec , sync_dist = True )
2069
- self .log ('test_top_l_prec' , top_l_prec , sync_dist = True )
2067
+ self .log ('test_top_l_recall' , top_l_recall , sync_dist = True )
2068
+ self .log ('test_top_l_by_2_recall' , top_l_by_2_recall , sync_dist = True )
2069
+ self .log ('test_top_l_by_5_recall' , top_l_by_5_recall , sync_dist = True )
2070
2070
2071
2071
return {
2072
2072
'loss' : loss ,
@@ -2082,6 +2082,9 @@ def test_step(self, batch, batch_idx):
2082
2082
'top_10_prec' : top_10_prec ,
2083
2083
'top_l_by_10_prec' : top_l_by_10_prec ,
2084
2084
'top_l_by_5_prec' : top_l_by_5_prec ,
2085
+ 'top_l_recall' : top_l_recall ,
2086
+ 'top_l_by_2_recall' : top_l_by_2_recall ,
2087
+ 'top_l_by_5_recall' : top_l_by_5_recall ,
2085
2088
'target' : filepaths [0 ].split (os .sep )[- 1 ][:4 ]
2086
2089
}
2087
2090
@@ -2112,17 +2115,20 @@ def test_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT):
2112
2115
test_preds_rounded = [(output_dict ['test_preds_rounded' ]) for output_dict in outputs ]
2113
2116
test_labels = [output_dict ['test_labels' ] for output_dict in outputs ]
2114
2117
2115
- # Write out test top-k precision results to CSV
2116
- prec_data = {
2118
+ # Write out test top-k metric results to CSV
2119
+ metrics_data = {
2117
2120
'top_10_prec' : [extract_object (output_dict ['top_10_prec' ]) for output_dict in outputs ],
2118
2121
'top_l_by_10_prec' : [extract_object (output_dict ['top_l_by_10_prec' ]) for output_dict in outputs ],
2119
2122
'top_l_by_5_prec' : [extract_object (output_dict ['top_l_by_5_prec' ]) for output_dict in outputs ],
2123
+ 'top_l_recall' : [extract_object (output_dict ['top_l_recall' ]) for output_dict in outputs ],
2124
+ 'top_l_by_2_recall' : [extract_object (output_dict ['top_l_by_2_recall' ]) for output_dict in outputs ],
2125
+ 'top_l_by_5_recall' : [extract_object (output_dict ['top_l_by_5_recall' ]) for output_dict in outputs ],
2120
2126
'target' : [extract_object (output_dict ['target' ]) for output_dict in outputs ],
2121
2127
}
2122
- prec_df = pd .DataFrame (data = prec_data )
2123
- prec_df_name_prefix = 'casp_capri' if self .testing_with_casp_capri else 'dips_plus_test'
2124
- prec_df_name = prec_df_name_prefix + '_top_prec .csv'
2125
- prec_df .to_csv (prec_df_name )
2128
+ metrics_df = pd .DataFrame (data = metrics_data )
2129
+ metrics_df_name_prefix = 'casp_capri' if self .testing_with_casp_capri else 'dips_plus_test'
2130
+ metrics_df_name = metrics_df_name_prefix + '_top_metrics .csv'
2131
+ metrics_df .to_csv (metrics_df_name )
2126
2132
2127
2133
if not self .testing_with_casp_capri : # Testing with DIPS-Plus
2128
2134
# Filter out all but the first 55 test predictions and labels to reduce storage requirements
0 commit comments