Diff of /utils/evalMetrics.py [000000] .. [f2ca4d]

Switch to side-by-side view

--- a
+++ b/utils/evalMetrics.py
@@ -0,0 +1,83 @@
+import numpy as np
+import sys
+import scipy.spatial
+from sklearn.metrics import f1_score
+from sklearn.metrics import precision_recall_fscore_support as score
+
+#evaluation functions
+def metricEval(eval_metric, output, gt, num_labels):
+    if eval_metric == 'iou':
+        return get_iou(output.squeeze(), gt.squeeze(), num_labels)
+    elif eval_metric == 'dice':
+        return get_dice(output.squeeze(), gt.squeeze(), num_labels)
+    elif eval_metric == 'recall':
+        return get_recall(output.squeeze(), gt.squeeze(), num_labels)
+    elif eval_metric == 'precision':
+        return get_precision(output.squeeze(), gt.squeeze(), num_labels)
+    else:
+        print('Invalid evaluation metric value')
+        sys.exit()
+    print('MY IOU', get_iou(output.squeeze(), gt.squeeze(), num_labels))
+    print('MY DICE', get_dice(output.squeeze(), gt.squeeze(), num_labels))
+    print('MY recll', get_recall(output.squeeze(), gt.squeeze(), num_labels))
+    print('MY PRECISION' , get_precision(output.squeeze(), gt.squeeze(), num_labels))
+    print(precision_recall_fscore_support(gt.reshape(-1), output.reshape(-1)))
+
+def get_iou(pred, gt, num_labels):
+    if pred.shape != gt.shape:
+        print('pred shape',pred.shape, 'gt shape', gt.shape)
+    assert(pred.shape == gt.shape)
+    gt = gt.astype(np.float32)
+    pred = pred.astype(np.float32)
+
+    gt = gt.reshape(-1)
+    pred = pred.reshape(-1)
+
+    max_label = num_labels-1
+    count = np.zeros((max_label+1,))
+    for j in range(max_label+1):
+        gt_loc = set(np.where(gt == j)[0])
+        pred_loc = set(np.where(pred == j)[0])
+
+        intersection = set.intersection(gt_loc, pred_loc)
+        union = set.union(gt_loc, pred_loc)
+
+        if len(gt_loc) != 0:
+            count[j] = float(len(intersection)) / float(len(union))
+    return np.sum(count) / float(num_labels)
+
+def get_dice(pred, gt, num_labels):
+    if num_labels != 2:
+        print('Dice evaluation score is only implemented for 2 labels')
+        sys.exit()
+    return 1.0 - scipy.spatial.distance.dice(pred.reshape(-1), gt.reshape(-1))
+
+#f1 score at beta = 1 is the same as dice score
+
+# recall = (num detected WMH) / (num true WMH)
+def get_recall(pred, gt, num_labels):
+    if num_labels != 2:
+        sys.exit()
+
+    gt = gt.reshape(-1)
+    pred = pred.reshape(-1)
+
+    gt_loc = set(np.where(gt == 1)[0])
+    pred_loc = set(np.where(pred == 1)[0])
+    TP = float(len(set.intersection(gt_loc, pred_loc)))
+    TPandFN = float(len(gt_loc))
+    return TP / TPandFN
+
+# precision = (number detected WMH) / (number of all detections)
+def get_precision(pred, gt, num_labels):
+    if num_labels != 2:
+        sys.exit()
+
+    gt = gt.reshape(-1)
+    pred = pred.reshape(-1)
+
+    gt_loc = set(np.where(gt == 1)[0])
+    pred_loc = set(np.where(pred == 1)[0])
+    TP = float(len(set.intersection(gt_loc, pred_loc)))
+    TPandFP = float(len(pred_loc))
+    return TP / TPandFP
\ No newline at end of file