--- a +++ b/ecg_classification/meter.py @@ -0,0 +1,49 @@ +import os +import itertools +import time +import random + +import numpy as np +import torch +from sklearn.metrics import accuracy_score, auc, f1_score, precision_score, recall_score + + +class Meter: + def __init__(self, n_classes=5): + self.metrics = {} + self.confusion = torch.zeros((n_classes, n_classes)) + + def update(self, x, y, loss): + x = np.argmax(x.detach().cpu().numpy(), axis=1) + y = y.detach().cpu().numpy() + self.metrics['loss'] += loss + self.metrics['accuracy'] += accuracy_score(x,y) + self.metrics['f1'] += f1_score(x,y,average='macro') + self.metrics['precision'] += precision_score(x, y, average='macro', zero_division=1) + self.metrics['recall'] += recall_score(x,y, average='macro', zero_division=1) + + self._compute_cm(x, y) + + def _compute_cm(self, x, y): + for prob, target in zip(x, y): + if prob == target: + self.confusion[target][target] += 1 + else: + self.confusion[target][prob] += 1 + + def init_metrics(self): + self.metrics['loss'] = 0 + self.metrics['accuracy'] = 0 + self.metrics['f1'] = 0 + self.metrics['precision'] = 0 + self.metrics['recall'] = 0 + + def get_metrics(self): + return self.metrics + + def get_confusion_matrix(self): + return self.confusion + + +if __name__ == '__main__': + meter = Meter()