Switch to unified view

a b/ecg_classification/meter.py
1
import os
2
import itertools
3
import time
4
import random
5
6
import numpy as np 
7
import torch
8
from sklearn.metrics import accuracy_score, auc, f1_score, precision_score, recall_score
9
10
11
class Meter:
12
    def __init__(self, n_classes=5):
13
        self.metrics = {}
14
        self.confusion = torch.zeros((n_classes, n_classes))
15
    
16
    def update(self, x, y, loss):
17
        x = np.argmax(x.detach().cpu().numpy(), axis=1)
18
        y = y.detach().cpu().numpy()
19
        self.metrics['loss'] += loss
20
        self.metrics['accuracy'] += accuracy_score(x,y)
21
        self.metrics['f1'] += f1_score(x,y,average='macro')
22
        self.metrics['precision'] += precision_score(x, y, average='macro', zero_division=1)
23
        self.metrics['recall'] += recall_score(x,y, average='macro', zero_division=1)
24
        
25
        self._compute_cm(x, y)
26
        
27
    def _compute_cm(self, x, y):
28
        for prob, target in zip(x, y):
29
            if prob == target:
30
                self.confusion[target][target] += 1
31
            else:
32
                self.confusion[target][prob] += 1
33
    
34
    def init_metrics(self):
35
        self.metrics['loss'] = 0
36
        self.metrics['accuracy'] = 0
37
        self.metrics['f1'] = 0
38
        self.metrics['precision'] = 0
39
        self.metrics['recall'] = 0
40
        
41
    def get_metrics(self):
42
        return self.metrics
43
    
44
    def get_confusion_matrix(self):
45
        return self.confusion
46
47
   
48
if __name__ == '__main__':
49
    meter = Meter()