trainers/Metrics.py
|
import csv
import math

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
|
|
def xfrange(start, stop, step):
    """Float range helper: yield start, start + step, ... while the value stays below stop."""
    i = 0
    while start + i * step < stop:
        yield start + i * step
        i += 1
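# Example: list(xfrange(0.0, 0.5, 0.1)) -> [0.0, 0.1, 0.2, 0.3, 0.4] (up to float rounding).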
|
|
def compute_prc(predictions, labels, filename=None, plottitle="Precision-Recall Curve"):
    precisions, recalls, thresholds = precision_recall_curve(labels.astype(int), predictions)
    auprc = average_precision_score(labels.astype(int), predictions)

    fig = plt.figure()
    plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
    plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title(f'{plottitle} (area = {auprc:.2f})')
    plt.show()

    # save a pdf to disk, plus the raw curve as CSV
    if filename:
        fig.savefig(filename)

        with open(filename + ".csv", mode="w", newline="") as csv_file:
            fieldnames = ["Precision", "Recall"]
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
            writer.writeheader()
            for i in range(len(precisions)):
                writer.writerow({"Precision": precisions[i], "Recall": recalls[i]})

    return auprc, precisions, recalls, thresholds
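# Hedged usage sketch (the names below are illustrative, not part of this module):
#   auprc, precisions, recalls, thresholds = compute_prc(scores.flatten(), mask.flatten(), filename="prc.pdf")
# where `scores` holds continuous per-pixel anomaly scores and `mask` the binary ground truth.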
|
|
def compute_roc(predictions, labels, filename=None, plottitle="ROC Curve"):
    _fpr, _tpr, _thresholds = roc_curve(labels.astype(int), predictions)
    roc_auc = auc(_fpr, _tpr)

    fig = plt.figure()
    plt.plot(_fpr, _tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(plottitle)
    plt.legend(loc="lower right")
    plt.show()

    # save a pdf to disk
    if filename:
        fig.savefig(filename)

    return roc_auc, _fpr, _tpr, _thresholds
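# Usage mirrors compute_prc: pass continuous scores and binary labels, e.g.
#   roc_auc, fpr_, tpr_, thresholds = compute_roc(scores.flatten(), mask.flatten())
# (illustrative names again).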
|
|
def dice(P, G):
    """Dice coefficient between a binary prediction mask P and ground-truth mask G."""
    psum = np.sum(P.flatten())
    gsum = np.sum(G.flatten())
    pgsum = np.sum(np.multiply(P.flatten(), G.flatten()))
    score = (2 * pgsum) / (psum + gsum)
    return score
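# Dice = 2|P ∧ G| / (|P| + |G|); e.g. dice(np.array([1, 1, 0]), np.array([1, 0, 0])) ≈ 0.667.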
|
|
def confusion_matrix(P, G):
    # P and G are expected to be boolean arrays; np.invert on plain 0/1 integer
    # arrays is a bitwise NOT and would produce wrong counts.
    tp = np.sum(np.multiply(P.flatten(), G.flatten()))
    fp = np.sum(np.multiply(P.flatten(), np.invert(G.flatten())))
    fn = np.sum(np.multiply(np.invert(P.flatten()), G.flatten()))
    tn = np.sum(np.multiply(np.invert(P.flatten()), np.invert(G.flatten())))
    return tp, fp, tn, fn
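# Example (note the boolean dtype; P and G are hypothetical mask arrays):
#   tp, fp, tn, fn = confusion_matrix(P.astype(bool), G.astype(bool))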
|
|
def tpr(P, G):
    """True positive rate (sensitivity) for boolean masks."""
    tp = np.sum(np.multiply(P.flatten(), G.flatten()))
    fn = np.sum(np.multiply(np.invert(P.flatten()), G.flatten()))
    return tp / (tp + fn)
|
|
def fpr(P, G):
    """False positive rate for boolean masks."""
    tn = np.sum(np.multiply(np.invert(P.flatten()), np.invert(G.flatten())))
    fp = np.sum(np.multiply(P.flatten(), np.invert(G.flatten())))
    return fp / (fp + tn)
|
|
def precision(P, G):
    """Positive predictive value for boolean masks."""
    tp = np.sum(np.multiply(P.flatten(), G.flatten()))
    fp = np.sum(np.multiply(P.flatten(), np.invert(G.flatten())))
    return tp / (tp + fp)
|
|
def recall(P, G):
    """Recall is identical to the true positive rate."""
    return tpr(P, G)
|
|
def vd(P, G):
    """Relative volume of ground-truth voxels missed by P: false negatives divided by |G|."""
    tps = np.multiply(P.flatten(), G.flatten())
    return np.sum(np.abs(np.logical_xor(tps, G.flatten()))) / np.sum(G.flatten())
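# Since tps is a subset of G, the XOR isolates the false negatives; e.g.
# vd(np.array([True, False]), np.array([True, True])) == 0.5 (one of two GT voxels missed).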
|
|
def compute_dice_curve_recursive(predictions, labels, filename=None, plottitle="DICE Curve", granularity=5):
    scores, threshs = compute_dice_score(predictions, labels, granularity)

    best_score, best_threshold = sorted(zip(scores, threshs), reverse=True)[0]

    min_threshs, max_threshs = min(threshs), max(threshs)
    buffer_range = math.fabs(min_threshs - max_threshs) * 0.02
    x_min, x_max = min_threshs - buffer_range, max_threshs + buffer_range
    fig = plt.figure()
    plt.plot(threshs, scores, color='darkorange', lw=2, label='DICE vs Threshold Curve')
    plt.xlim([x_min, x_max])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Thresholds')
    plt.ylabel('DICE Score')
    plt.title(plottitle)
    plt.legend(loc="lower right")
    plt.text(x_max - x_max * 0.01, 1, f'Best dice score at {best_threshold:.5f} with {best_score:.4f}',
             horizontalalignment='right', verticalalignment='top')
    plt.show()

    # save a pdf to disk
    if filename:
        fig.savefig(filename)

    return best_score, best_threshold
|
|
def compute_dice_score(predictions, labels, granularity):
    def inner_compute_dice_curve_recursive(start, stop, decimal):
        _threshs = []
        _scores = []
        had_recursion = False

        if decimal == granularity:
            return _threshs, _scores

        for i, t in enumerate(xfrange(start, stop, 1.0 / (10.0 ** decimal))):
            score = dice(np.where(predictions > t, 1, 0), labels)
            # Once the curve starts to drop, refine the bracket around the peak
            # one decimal deeper. Checking had_recursion first keeps the index
            # i - 1 aligned with this loop's own entries after the lists are
            # extended with sub-results.
            if i >= 2 and not had_recursion and score <= _scores[i - 1]:
                _subthreshs, _subscores = inner_compute_dice_curve_recursive(_threshs[i - 2], t, decimal + 1)
                _threshs.extend(_subthreshs)
                _scores.extend(_subscores)
                had_recursion = True
            _scores.append(score)
            _threshs.append(t)

        return _threshs, _scores

    threshs, scores = inner_compute_dice_curve_recursive(0, 1.0, 1)
    sorted_pairs = sorted(zip(threshs, scores))
    threshs, scores = list(zip(*sorted_pairs))
    return scores, threshs
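# The top-level sweep uses 0.1 steps (decimal=1) and recurses one decimal at a
# time around the first local maximum; since recursion stops once
# decimal == granularity, the finest pass uses steps of 1.0 / 10 ** (granularity - 1).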
|
|
# Predictive pixel-wise variance combining aleatoric and epistemic model uncertainty,
# as proposed in "What Uncertainties Do We Need in Bayesian Deep Learning for
# Computer Vision?" (Kendall & Gal, 2017).
# p: tensor of n Monte Carlo regression results
# sigmas: tensor of the same shape holding the variances predicted by the network
#         (log-variances if log_var=True)
# axis: index of the axis that stores the Monte Carlo samples
def combined_predictive_uncertainty(p, sigmas, axis=-1, log_var=False):
    if log_var:
        sigmas = np.exp(sigmas)
    # Var[y] ≈ E[p^2] - E[p]^2 (epistemic) + E[sigma^2] (aleatoric)
    return np.mean(np.square(p), axis=axis) - np.square(np.mean(p, axis=axis)) + np.mean(sigmas, axis=axis)
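

# A minimal, hedged demo on synthetic data (shapes and values are illustrative,
# not from the original module): T Monte Carlo passes over a small image, each
# predicting a mean and a log-variance per pixel.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, H, W = 20, 4, 4
    mc_means = rng.normal(0.5, 0.1, size=(H, W, T))      # per-sample predictions
    mc_log_vars = rng.normal(-3.0, 0.1, size=(H, W, T))  # predicted log-variances
    uncertainty = combined_predictive_uncertainty(mc_means, mc_log_vars, axis=-1, log_var=True)
    print(uncertainty.shape)  # -> (4, 4): one combined variance per pixel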