Diff of /src/metrics.py [000000] .. [f45789]

Switch to unified view

a b/src/metrics.py
1
import os
2
import numpy as np
3
import matplotlib.pyplot as plt
4
from sklearn.metrics import (precision_recall_curve as pr_curve,
5
                             confusion_matrix,
6
                             precision_score,
7
                             accuracy_score,
8
                             recall_score,
9
                             roc_curve,
10
                             auc)
11
12
def roc_auc(ground_truth, inferences, experiment_dir):
13
    fpr, tpr, threshold = roc_curve(ground_truth, inferences)
14
    roc_auc = auc(fpr, tpr).item()
15
    plt.title('Receiver Operating Characteristic')
16
    plt.plot(fpr, tpr, 'b', label = 'AUC = %0.3f' % roc_auc)
17
    plt.legend(loc = 'lower right')
18
    plt.plot([0, 1], [0, 1],'r--')
19
    plt.xlim([0, 1])
20
    plt.ylim([0, 1])
21
    plt.ylabel('True Positive Rate')
22
    plt.xlabel('False Positive Rate')
23
    plt.savefig(os.path.join(experiment_dir, 'roc_auc.png'))
24
    plt.clf()
25
    return round(roc_auc, 3)
26
27
28
def pr_auc(ground_truth, inferences, experiment_dir):
29
    precision, recall, threshold = pr_curve(ground_truth, inferences)
30
    #precision, recall, threshold = sorted(zip(precision, recall))
31
    pr_auc = auc(recall, precision).item()
32
    plt.title('Precision - Recall')
33
    plt.plot(recall, precision, 'b', label = 'AUC = %0.3f' % pr_auc)
34
    plt.legend(loc = 'lower right')
35
    plt.xlim([0, 1])
36
    plt.ylim([0, 1])
37
    plt.ylabel('Precision')
38
    plt.xlabel('Recall')
39
    plt.savefig(os.path.join(experiment_dir, 'pr_auc.png'))
40
    plt.clf()
41
    return round(pr_auc, 3)
42
43
44
def calc_metrics(ground_truth, inferences, normalize=None, threshold=0.5):
45
    y_pred = inferences > threshold
46
    #accuracy = accuracy_score(y_true=ground_truth, y_pred=y_pred).item()
47
    #recall = recall_score(y_true=ground_truth, y_pred=y_pred).item()
48
    #precision = precision_score(y_true=ground_truth, y_pred=y_pred).item()
49
50
    conf_mat = confusion_matrix(y_true=ground_truth,
51
                                y_pred=y_pred,
52
                                normalize=normalize)
53
    tn, fp, fn, tp = list(map(lambda x: x.item(), conf_mat.ravel()))
54
55
    tpr = tp / (tp + fn)
56
    tnr = tn / (tn + fp)
57
    ppv = tp / (tp + fp)
58
    npv = tn / (tn + fn)
59
    acc = (tp+tn) / (tp+tn+fp+fn)
60
61
    metrics = {'threshold': threshold,
62
               'TPR': tpr,
63
               'TNR': tnr,
64
               'PPV': ppv,
65
               'NPV': npv,
66
               'ACC': acc}
67
    return {k: round(v, 3) for k,v in metrics.items()}