1- """Logic file"""
1+ """Metrics file"""
22import argparse
33import glob
44import yaml
88
99
1010def dice_coef_metric (
11- probabilities : np .ndarray , truth : np .ndarray , treshold : float = 0.5 , eps : float = 0
11+ predictions : np .ndarray , truth : np .ndarray , treshold : float = 0.5 , eps : float = 0
1212) -> np .ndarray :
1313 """
1414 Calculate Dice score for data batch.
1515 Params:
16- probobilities : model outputs after activation function.
16+ predictions : model outputs after activation function.
1717 truth: truth values.
18- threshold: threshold for probabilities .
18+ threshold: threshold for predictions .
1919 eps: additive to refine the estimate.
2020 Returns: dice score aka f1.
2121 """
22+
2223 scores = []
23- num = probabilities .shape [0 ]
24- predictions = probabilities >= treshold
24+ num = predictions .shape [0 ]
25+ predictions = predictions >= treshold
2526 assert predictions .shape == truth .shape
2627 for i in range (num ):
2728 prediction = predictions [i ]
@@ -36,20 +37,21 @@ def dice_coef_metric(
3637
3738
3839def jaccard_coef_metric (
39- probabilities : np .ndarray , truth : np .ndarray , treshold : float = 0.5 , eps : float = 0
40+ predictions : np .ndarray , truth : np .ndarray , treshold : float = 0.5 , eps : float = 0
4041) -> np .ndarray :
4142 """
4243 Calculate Jaccard index for data batch.
4344 Params:
44- probobilities : model outputs after activation function.
45+ predictions : model outputs after activation function.
4546 truth: truth values.
46- threshold: threshold for probabilities .
47+ threshold: threshold for predictions .
4748 eps: additive to refine the estimate.
4849 Returns: jaccard score aka iou."
4950 """
51+
5052 scores = []
51- num = probabilities .shape [0 ]
52- predictions = probabilities >= treshold
53+ num = predictions .shape [0 ]
54+ predictions = predictions >= treshold
5355 assert predictions .shape == truth .shape
5456
5557 for i in range (num ):
@@ -65,6 +67,7 @@ def jaccard_coef_metric(
6567
6668
6769def preprocess_mask_labels (mask : np .ndarray ):
70+ """Preprocess the mask labels from a numpy array"""
6871
6972 mask_WT = mask .copy ()
7073 mask_WT [mask_WT == 1 ] = 1
@@ -88,12 +91,17 @@ def preprocess_mask_labels(mask: np.ndarray):
8891
8992
9093def load_img (file_path ):
94+ """Reads segmentations image as a numpy array"""
95+
9196 data = nib .load (file_path )
9297 data = np .asarray (data .dataobj )
9398 return data
9499
95100
96101def get_data_arr (predictions_path , ground_truth_path ):
102+ """Reads the content for the predictions and ground truth folders
103+ and then returns the data in numpy array format"""
104+
97105 predictions = glob .glob (predictions_path + "/*" )
98106 ground_truth = glob .glob (ground_truth_path + "/*" )
99107 if not len (predictions ) == len (ground_truth ):
@@ -114,11 +122,14 @@ def get_data_arr(predictions_path, ground_truth_path):
114122
115123
116124def create_metrics_file (output_file , results ):
125+ """Writes metrics to an output yaml file"""
117126 with open (output_file , "w" ) as f :
118127 yaml .dump (results , f )
119128
120129
121130def main ():
131+ """Main function that recieves input parameters and calculate metrics"""
132+
122133 parser = argparse .ArgumentParser ()
123134 parser .add_argument (
124135 "--ground_truth" ,
0 commit comments