Diff of /trainers/Metrics.py [000000] .. [978658]

Switch to unified view

a b/trainers/Metrics.py
1
import csv
2
import math
3
4
import matplotlib
5
import matplotlib.pyplot as plt
6
import numpy as np
7
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
8
9
10
def xfrange(start, stop, step):
11
    i = 0
12
    while start + i * step < stop:
13
        yield start + i * step
14
        i += 1
15
16
17
def compute_prc(predictions, labels, filename=None, plottitle="Precision-Recall Curve"):
18
    precisions, recalls, thresholds = precision_recall_curve(labels.astype(int), predictions)
19
    auprc = average_precision_score(labels.astype(int), predictions)
20
21
    fig = matplotlib.pyplot.figure()
22
    matplotlib.pyplot.step(recalls, precisions, color='b', alpha=0.2, where='post')
23
    matplotlib.pyplot.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
24
    matplotlib.pyplot.xlabel('Recall')
25
    matplotlib.pyplot.ylabel('Precision')
26
    matplotlib.pyplot.ylim([0.0, 1.05])
27
    matplotlib.pyplot.xlim([0.0, 1.0])
28
    matplotlib.pyplot.title(f'{plottitle} (area = {auprc:.2f}.)')
29
    matplotlib.pyplot.show()
30
31
    # save a pdf to disk
32
    if filename:
33
        fig.savefig(filename)
34
35
        with open(filename + ".csv", mode="w") as csv_file:
36
            fieldnames = ["Precision", "Recall"]
37
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
38
            writer.writeheader()
39
            for i in range(len(precisions)):
40
                writer.writerow({"Precision": precisions[i], "Recall": recalls[i]})
41
42
    return auprc, precisions, recalls, thresholds
43
44
45
def compute_roc(predictions, labels, filename=None, plottitle="ROC Curve"):
46
    _fpr, _tpr, _ = roc_curve(labels.astype(int), predictions)
47
    roc_auc = auc(_fpr, _tpr)
48
49
    fig = matplotlib.pyplot.figure()
50
    matplotlib.pyplot.plot(_fpr, _tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
51
    matplotlib.pyplot.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
52
    matplotlib.pyplot.xlim([0.0, 1.0])
53
    matplotlib.pyplot.ylim([0.0, 1.05])
54
    matplotlib.pyplot.xlabel('False Positive Rate')
55
    matplotlib.pyplot.ylabel('True Positive Rate')
56
    matplotlib.pyplot.title(plottitle)
57
    matplotlib.pyplot.legend(loc="lower right")
58
    matplotlib.pyplot.show()
59
60
    # save a pdf to disk
61
    if filename:
62
        fig.savefig(filename)
63
64
    return roc_auc, _fpr, _tpr, _
65
66
67
def dice(P, G):
68
    psum = np.sum(P.flatten())
69
    gsum = np.sum(G.flatten())
70
    pgsum = np.sum(np.multiply(P.flatten(), G.flatten()))
71
    score = (2 * pgsum) / (psum + gsum)
72
    return score
73
74
75
def confusion_matrix(P, G):
76
    tp = np.sum(np.multiply(P.flatten(), G.flatten()))
77
    fp = np.sum(np.multiply(P.flatten(), np.invert(G.flatten())))
78
    fn = np.sum(np.multiply(np.invert(P.flatten()), G.flatten()))
79
    tn = np.sum(np.multiply(np.invert(P.flatten()), np.invert(G.flatten())))
80
    return tp, fp, tn, fn
81
82
83
def tpr(P, G):
84
    tp = np.sum(np.multiply(P.flatten(), G.flatten()))
85
    fn = np.sum(np.multiply(np.invert(P.flatten()), G.flatten()))
86
    return tp / (tp + fn)
87
88
89
def fpr(P, G):
90
    tn = np.sum(np.multiply(np.invert(P.flatten()), np.invert(G.flatten())))
91
    fp = np.sum(np.multiply(P.flatten(), np.invert(G.flatten())))
92
    return fp / (fp + tn)
93
94
95
def precision(P, G):
96
    tp = np.sum(np.multiply(P.flatten(), G.flatten()))
97
    fp = np.sum(np.multiply(P.flatten(), np.invert(G.flatten())))
98
    return tp / (tp + fp)
99
100
101
def recall(P, G):
102
    return tpr(P, G)
103
104
105
def vd(P, G):
106
    tps = np.multiply(P.flatten(), G.flatten())
107
    return np.sum(np.abs(np.logical_xor(tps, G.flatten()))) / np.sum(G.flatten())
108
109
110
def compute_dice_curve_recursive(predictions, labels, filename=None, plottitle="DICE Curve", granularity=5):
111
    scores, threshs = compute_dice_score(predictions, labels, granularity)
112
113
    best_score, best_threshold = sorted(zip(scores, threshs), reverse=True)[0]
114
115
    min_threshs, max_threshs = min(threshs), max(threshs)
116
    buffer_range = math.fabs(min_threshs - max_threshs) * 0.02
117
    x_min, x_max = min(threshs) - buffer_range, max(threshs) + buffer_range
118
    fig = matplotlib.pyplot.figure()
119
    matplotlib.pyplot.plot(threshs, scores, color='darkorange', lw=2, label='DICE vs Threshold Curve')
120
    matplotlib.pyplot.xlim([x_min, x_max])
121
    matplotlib.pyplot.ylim([0.0, 1.05])
122
    matplotlib.pyplot.xlabel('Thresholds')
123
    matplotlib.pyplot.ylabel('DICE Score')
124
    matplotlib.pyplot.title(plottitle)
125
    matplotlib.pyplot.legend(loc="lower right")
126
    matplotlib.pyplot.text(x_max - x_max * 0.01, 1, f'Best dice score at {best_threshold:.5f} with {best_score:.4f}', horizontalalignment='right',
127
                           verticalalignment='top')
128
    matplotlib.pyplot.show()
129
130
    # save a pdf to disk
131
    if filename:
132
        fig.savefig(filename)
133
134
    bestthresh_idx = np.argmax(scores)
135
    return scores[bestthresh_idx], threshs[bestthresh_idx]
136
137
138
def compute_dice_score(predictions, labels, granularity):
139
    def inner_compute_dice_curve_recursive(start, stop, decimal):
140
        _threshs = []
141
        _scores = []
142
        had_recursion = False
143
144
        if decimal == granularity:
145
            return _threshs, _scores
146
147
        for i, t in enumerate(xfrange(start, stop, (1.0 / (10.0 ** decimal)))):
148
            score = dice(np.where(predictions > t, 1, 0), labels)
149
            if i >= 2 and score <= _scores[i - 1] and not had_recursion:
150
                _subthreshs, _subscores = inner_compute_dice_curve_recursive(_threshs[i - 2], t, decimal + 1)
151
                _threshs.extend(_subthreshs)
152
                _scores.extend(_subscores)
153
                had_recursion = True
154
            _scores.append(score)
155
            _threshs.append(t)
156
157
        return _threshs, _scores
158
159
    threshs, scores = inner_compute_dice_curve_recursive(0, 1.0, 1)
160
    sorted_pairs = sorted(zip(threshs, scores))
161
    threshs, scores = list(zip(*sorted_pairs))
162
    return scores, threshs
163
164
165
# Predictive pixel-wise variance combining aleatoric and epistemic model uncertainty
166
# As seen in "What Uncertainties Do we Need in Bayesian Deep Learning for Computer Vision"
167
# p is a tensor of n monte carlo regression results
168
# sigma is the same for variances predicted by the network
169
# axis defines the index of the axis which stores the monte carlo samples
170
def combined_predictive_uncertainty(p, sigmas, axis=-1, log_var=False):
171
    if log_var:
172
        sigmas = np.exp(sigmas)
173
    return np.mean(np.square(p), axis=axis) - np.square(np.mean(p, axis=axis)) + np.mean(sigmas, axis=axis)