Diff of /utilities/metricUtils.py [000000] .. [a18f15]

Switch to unified view

a b/utilities/metricUtils.py
1
import os, csv
2
import numpy as np
3
import torch
4
from sklearn import metrics as skmetrics
5
import matplotlib.pyplot as plt
6
7
class MultiClassMetrics():
8
    def __init__(self, logpath):
9
        self.tgt = []
10
        self.prd = []
11
        self.nnloss = []
12
        self.logpath = logpath
13
14
    def reset(self, save_results = False):
15
        if save_results: self._write_predictions()
16
        self.__init__(self.logpath)
17
18
    def add_entry(self, prd, tgt, loss=0):
19
        self.prd.extend(prd.cpu().detach().numpy())
20
        self.tgt.extend(tgt.cpu().detach().numpy())
21
        if loss: self.nnloss.append(loss.cpu().detach().numpy())
22
23
    def get_loss(self):
24
        return sum(self.nnloss) / len(self.nnloss)
25
26
    def get_accuracy(self):
27
        return skmetrics.accuracy_score(self.tgt, self.prd)
28
29
    def get_balanced_accuracy(self):
30
        return skmetrics.balanced_accuracy_score(self.tgt, self.prd)
31
32
    def get_f1score(self):
33
        return skmetrics.f1_score(self.tgt, self.prd, average='macro')
34
35
    def get_class_report(self):
36
        return skmetrics.classification_report(self.tgt, self.prd,
37
                    output_dict= True)
38
39
    def get_confusion_matrix(self, save_png = False, title=""):
40
        lbls = sorted(list(set(self.tgt)))
41
        cm = skmetrics.confusion_matrix(self.tgt, self.prd,
42
                                labels= lbls)
43
        if save_png:
44
            disp = skmetrics.ConfusionMatrixDisplay(confusion_matrix=cm,
45
                                        display_labels=lbls).plot()
46
            plt.savefig(self.logpath+f'/{title}Confusion.png', bbox_inches='tight')
47
        return cm
48
49
    def _write_predictions(self, title=""):
50
        with open(os.path.join(self.logpath, f"{title}Predict.csv"), 'w') as f:
51
            writer = csv.writer(f)
52
            writer.writerow(["target", "prediction"])
53
            writer.writerows(zip(self.tgt, self.prd))
54
55
56
57
if __name__ == "__main__":
58
59
    obj = MultiClassMetrics()
60
    obj.tgt = [1,1,1,2,2,2,3,3,3,4,4,4,5,5,5]
61
    obj.prd = [1,1,2,2,2,3,3,3,4,4,4,5,5,5,1]
62
63
    print(obj.get_class_report())