|
a |
|
b/metrics.py |
|
|
1 |
# Adapted from score written by wkentaro |
|
|
2 |
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py |
|
|
3 |
|
|
|
4 |
import numpy as np |
|
|
5 |
|
|
|
6 |
class runningScore(object):
    """Accumulate a confusion matrix over batches and derive Dice / IoU scores.

    ``confusion_matrix[i, j]`` holds the number of elements whose true label
    is ``i`` and whose predicted label is ``j``.
    """

    def __init__(self, n_classes):
        """Create an empty ``n_classes`` x ``n_classes`` confusion matrix."""
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def _fast_hist(self, label_true, label_pred, n_class):
        """Return the ``n_class`` x ``n_class`` confusion matrix of one sample.

        Args:
            label_true: 1-D array of ground-truth labels. Entries outside
                ``[0, n_class)`` (e.g. an ignore label) are dropped.
            label_pred: 1-D array of predicted labels, same length as
                ``label_true``.
            n_class: number of classes.

        Returns:
            Integer array of shape ``(n_class, n_class)``; rows index the true
            class, columns the predicted class.
        """
        mask = (label_true >= 0) & (label_true < n_class)
        # Encode each (true, pred) pair as one flat index so a single bincount
        # builds the whole matrix. Both operands are cast to int because
        # np.bincount rejects float input (e.g. predictions stored as floats).
        # NOTE: only label_true is range-checked; predictions are not filtered.
        flat_index = (
            n_class * label_true[mask].astype(int)
            + label_pred[mask].astype(int)
        )
        hist = np.bincount(flat_index, minlength=n_class ** 2)
        return hist.reshape(n_class, n_class)

    def update(self, label_trues, label_preds):
        """Add a batch of (ground truth, prediction) pairs to the running matrix.

        Args:
            label_trues: iterable of ground-truth label arrays.
            label_preds: iterable of predicted label arrays, paired
                element-wise with ``label_trues``.
        """
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(
                lt.flatten(), lp.flatten(), self.n_classes
            )

    def get_scores(self):
        """Compute Dice and IoU scores from the accumulated confusion matrix.

        Returns:
            A tuple ``(scores, cls_iu)`` where ``scores`` is a dict holding
            the per-class Dice array and the mean Dice, and ``cls_iu`` maps
            each class index to its IoU. Classes that never occur in either
            the labels or the predictions get ``nan``.
        """
        hist = self.confusion_matrix
        true_pos = np.diag(hist)

        # A class absent from both labels and predictions yields 0/0; the
        # resulting nan is intentional, so silence the floating-point warning.
        with np.errstate(divide="ignore", invalid="ignore"):
            iu = true_pos / (hist.sum(axis=1) + hist.sum(axis=0) - true_pos)
        dice = 2 * iu / (iu + 1)

        # Mean Dice is the average of the per-class Dice values. Averaging
        # IoU first and converting afterwards is not equivalent, because the
        # IoU -> Dice mapping is non-linear, and would overestimate the score.
        # NOTE(review): the slice keeps class indices 1..8 only (index 0 is
        # presumably background, and the upper bound is hard-coded) - confirm
        # against the dataset before using this with more than 9 classes.
        selected = dice[1:9]
        valid = selected[~np.isnan(selected)]
        mean_dice = valid.mean() if valid.size else np.nan

        cls_iu = dict(zip(range(self.n_classes), iu))

        return {
            'Dice : \t': dice,
            'Mean Dice : \t': mean_dice,
        }, cls_iu

    def reset(self):
        """Clear the accumulated confusion matrix."""
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))