[f2ca4d]: / utils / evalMetrics.py

Download this file

83 lines (69 with data), 2.9 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
import sys
import scipy.spatial
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_fscore_support as score
#evaluation functions
def metricEval(eval_metric, output, gt, num_labels):
    """Compute the requested evaluation metric for a prediction.

    Args:
        eval_metric: One of 'iou', 'dice', 'recall' or 'precision'.
        output: Predicted label array; ``.squeeze()`` is applied before it
            is handed to the metric function.
        gt: Ground-truth label array; ``.squeeze()`` is applied as well.
        num_labels: Number of labels, forwarded to the metric function.

    Returns:
        The value returned by the selected ``get_*`` metric function.

    Raises:
        SystemExit: If ``eval_metric`` is not one of the supported names
            (a message is printed first).
    """
    if eval_metric == 'iou':
        metric_fn = get_iou
    elif eval_metric == 'dice':
        metric_fn = get_dice
    elif eval_metric == 'recall':
        metric_fn = get_recall
    elif eval_metric == 'precision':
        metric_fn = get_precision
    else:
        print('Invalid evaluation metric value')
        sys.exit()
    # The debug prints that used to follow the dispatch were unreachable
    # (every branch returns or exits) and referenced a name that is not
    # defined in this module, so they have been dropped.
    return metric_fn(output.squeeze(), gt.squeeze(), num_labels)
def get_iou(pred, gt, num_labels):
    """Return the mean intersection-over-union across labels.

    For every label ``j`` in ``range(num_labels)`` the IoU between the
    positions where ``gt == j`` and the positions where ``pred == j`` is
    computed. A label that does not occur in ``gt`` contributes 0 to the
    sum, and the sum is always divided by ``num_labels``.

    Args:
        pred: Predicted label array (same shape as ``gt``).
        gt: Ground-truth label array.
        num_labels: Number of labels; labels are ``0 .. num_labels - 1``.

    Returns:
        The sum of the per-label IoU values divided by ``num_labels``.

    Raises:
        AssertionError: If ``pred`` and ``gt`` have different shapes.
    """
    # Raise explicitly instead of using ``assert`` so the check still runs
    # when Python is started with -O.
    if pred.shape != gt.shape:
        raise AssertionError(
            'pred shape {} gt shape {}'.format(pred.shape, gt.shape)
        )

    # Both arrays are cast to float32 before comparing against label ids.
    gt_flat = gt.astype(np.float32).reshape(-1)
    pred_flat = pred.astype(np.float32).reshape(-1)

    per_label_iou = np.zeros((num_labels,))
    for label in range(num_labels):
        gt_mask = gt_flat == label
        if not gt_mask.any():
            # Label absent from the ground truth: its entry stays 0.
            continue
        pred_mask = pred_flat == label
        # Boolean masks replace Python sets of indices: same counts,
        # computed in native code.
        intersection = np.count_nonzero(gt_mask & pred_mask)
        union = np.count_nonzero(gt_mask | pred_mask)
        per_label_iou[label] = float(intersection) / float(union)

    return np.sum(per_label_iou) / float(num_labels)
def get_dice(pred, gt, num_labels):
    """Return the Dice score between a flattened prediction and ground truth.

    The score is ``1.0`` minus ``scipy.spatial.distance.dice`` applied to
    the flattened arrays. Only the two-label case is supported; for any
    other ``num_labels`` a message is printed and ``sys.exit()`` is called.
    """
    if num_labels == 2:
        flat_pred = pred.reshape(-1)
        flat_gt = gt.reshape(-1)
        dissimilarity = scipy.spatial.distance.dice(flat_pred, flat_gt)
        return 1.0 - dissimilarity

    print('Dice evaluation score is only implemented for 2 labels')
    sys.exit()
#f1 score at beta = 1 is the same as dice score
# recall = (num detected WMH) / (num true WMH)
def get_recall(pred, gt, num_labels):
    """Return the recall of label 1 for a two-label segmentation.

    recall = TP / (TP + FN), where TP is the number of positions labelled 1
    in both ``pred`` and ``gt`` and TP + FN is the number of positions
    labelled 1 in ``gt``.

    Args:
        pred: Predicted label array.
        gt: Ground-truth label array.
        num_labels: Must be 2.

    Returns:
        The recall as a float. When ``gt`` contains no position labelled 1
        the ratio is undefined and 0.0 is returned instead of raising
        ``ZeroDivisionError``.

    Raises:
        SystemExit: If ``num_labels`` is not 2 (a message is printed first).
    """
    if num_labels != 2:
        print('Recall evaluation score is only implemented for 2 labels')
        sys.exit()

    gt_idx = np.flatnonzero(gt.reshape(-1) == 1)
    if gt_idx.size == 0:
        # No true positives exist to be recalled; avoid dividing by zero.
        return 0.0

    pred_idx = np.flatnonzero(pred.reshape(-1) == 1)
    # flatnonzero yields unique indices, so assume_unique is safe.
    true_positives = np.intersect1d(gt_idx, pred_idx, assume_unique=True).size
    return float(true_positives) / float(gt_idx.size)
# precision = (number detected WMH) / (number of all detections)
def get_precision(pred, gt, num_labels):
    """Return the precision of label 1 for a two-label segmentation.

    precision = TP / (TP + FP), where TP is the number of positions
    labelled 1 in both ``pred`` and ``gt`` and TP + FP is the number of
    positions labelled 1 in ``pred``.

    Args:
        pred: Predicted label array.
        gt: Ground-truth label array.
        num_labels: Must be 2.

    Returns:
        The precision as a float. When ``pred`` contains no position
        labelled 1 the ratio is undefined and 0.0 is returned instead of
        raising ``ZeroDivisionError``.

    Raises:
        SystemExit: If ``num_labels`` is not 2 (a message is printed first).
    """
    if num_labels != 2:
        print('Precision evaluation score is only implemented for 2 labels')
        sys.exit()

    pred_idx = np.flatnonzero(pred.reshape(-1) == 1)
    if pred_idx.size == 0:
        # Nothing was detected, so there is no detection to be correct.
        return 0.0

    gt_idx = np.flatnonzero(gt.reshape(-1) == 1)
    # flatnonzero yields unique indices, so assume_unique is safe.
    true_positives = np.intersect1d(gt_idx, pred_idx, assume_unique=True).size
    return float(true_positives) / float(pred_idx.size)