a b/metrics.py
1
# Adapted from the scoring utilities written by wkentaro
2
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
3
4
import numpy as np
5
6
class runningScore(object):
    """Accumulate a confusion matrix over many samples and derive Dice/IoU scores.

    ``confusion_matrix[t, p]`` counts elements whose ground-truth label is
    ``t`` and whose predicted label is ``p``; counts accumulate across calls
    to ``update`` until ``reset`` is called.
    """

    def __init__(self, n_classes):
        """Create an empty accumulator for labels ``0 .. n_classes - 1``."""
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def _fast_hist(self, label_true, label_pred, n_class):
        """Return the ``n_class x n_class`` confusion histogram of one sample.

        Rows index the ground-truth label, columns the predicted label.
        Elements whose ground-truth label lies outside ``[0, n_class)`` are
        skipped, so such values act as "ignore" labels. Predicted labels of
        the kept elements are assumed to lie in ``[0, n_class)``; otherwise
        ``np.bincount``/``reshape`` fails.
        """
        mask = (label_true >= 0) & (label_true < n_class)
        # Cast both operands: a float-valued prediction array would otherwise
        # produce a float index array, which np.bincount rejects.
        flat_index = n_class * label_true[mask].astype(int) + label_pred[mask].astype(int)
        return np.bincount(flat_index, minlength=n_class ** 2).reshape(n_class, n_class)

    def update(self, label_trues, label_preds):
        """Add the confusion counts of paired ground-truth / prediction arrays.

        ``label_trues`` and ``label_preds`` are iterated in lockstep with
        ``zip`` (e.g. over the batch dimension); each pair is flattened
        before counting, so paired items must hold the same number of
        elements.
        """
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)

    def get_scores(self, eval_classes=slice(1, 9)):
        """Compute per-class Dice, mean Dice and per-class IoU.

        Args:
            eval_classes: Index (slice, list or array) selecting the classes
                averaged into "Mean Dice". The default ``slice(1, 9)`` keeps
                the historical behaviour of excluding class 0 and using at
                most classes 1-8 (presumably background + 8 foreground
                classes -- confirm for your dataset).

        Returns:
            A tuple ``(scores, cls_iu)``. ``scores`` maps ``'Dice : \\t'`` to
            the per-class Dice array and ``'Mean Dice : \\t'`` to the mean of
            the selected per-class Dice values (nan classes skipped).
            ``cls_iu`` maps each class index to its IoU. A class that never
            occurs in either ground truth or prediction has nan Dice/IoU.
        """
        hist = self.confusion_matrix
        tp = np.diag(hist)
        gt_count = hist.sum(axis=1)
        pred_count = hist.sum(axis=0)
        # Classes absent from both ground truth and prediction give 0/0 = nan;
        # that is intended, so silence the floating-point warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            iu = tp / (gt_count + pred_count - tp)
            # Dice = 2|A∩B| / (|A| + |B|), equal to 2*IoU / (IoU + 1) per class.
            dice = 2 * tp / (gt_count + pred_count)
        # Average the per-class Dice values; deriving it from the mean IoU
        # (2*mIoU / (1 + mIoU)) would overstate it since the map is concave.
        mean_dice = np.nanmean(dice[eval_classes])
        cls_iu = dict(zip(range(self.n_classes), iu))

        return {'Dice : \t': dice,
                'Mean Dice : \t': mean_dice}, cls_iu

    def reset(self):
        """Discard all accumulated counts."""
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))