[f45789]: / src / metrics.py

Download this file

67 lines (59 with data), 2.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (precision_recall_curve as pr_curve,
confusion_matrix,
precision_score,
accuracy_score,
recall_score,
roc_curve,
auc)
def roc_auc(ground_truth, inferences, experiment_dir):
    """Plot the ROC curve, save it as ``roc_auc.png`` and return the AUC.

    Args:
        ground_truth: True labels, passed unchanged to
            ``sklearn.metrics.roc_curve``.
        inferences: Predicted scores, passed unchanged to
            ``sklearn.metrics.roc_curve``.
        experiment_dir: Directory the plot is written into. It is not
            created here, so presumably it must already exist.

    Returns:
        The area under the ROC curve, rounded to three decimals.
    """
    fpr, tpr, _ = roc_curve(ground_truth, inferences)
    area = auc(fpr, tpr).item()

    # Draw on a dedicated figure instead of pyplot's implicit "current"
    # figure, so a figure the caller is still working on is never drawn
    # over or cleared.
    fig, ax = plt.subplots()
    try:
        ax.set_title('Receiver Operating Characteristic')
        ax.plot(fpr, tpr, 'b', label='AUC = %0.3f' % area)
        ax.legend(loc='lower right')
        # Diagonal reference line from (0, 0) to (1, 1).
        ax.plot([0, 1], [0, 1], 'r--')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_ylabel('True Positive Rate')
        ax.set_xlabel('False Positive Rate')
        fig.savefig(os.path.join(experiment_dir, 'roc_auc.png'))
    finally:
        # Release the figure even if saving fails; otherwise repeated calls
        # accumulate open figures.
        plt.close(fig)
    return round(area, 3)
def pr_auc(ground_truth, inferences, experiment_dir):
    """Plot the precision-recall curve, save it as ``pr_auc.png`` and
    return the area under it.

    Args:
        ground_truth: True labels, passed unchanged to
            ``sklearn.metrics.precision_recall_curve``.
        inferences: Predicted scores, passed unchanged to
            ``sklearn.metrics.precision_recall_curve``.
        experiment_dir: Directory the plot is written into. It is not
            created here, so presumably it must already exist.

    Returns:
        The area under the precision-recall curve (``auc`` with recall on
        the x axis and precision on the y axis), rounded to three decimals.
    """
    precision, recall, _ = pr_curve(ground_truth, inferences)
    area = auc(recall, precision).item()

    # Draw on a dedicated figure instead of pyplot's implicit "current"
    # figure, so a figure the caller is still working on is never drawn
    # over or cleared.
    fig, ax = plt.subplots()
    try:
        ax.set_title('Precision - Recall')
        ax.plot(recall, precision, 'b', label='AUC = %0.3f' % area)
        ax.legend(loc='lower right')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_ylabel('Precision')
        ax.set_xlabel('Recall')
        fig.savefig(os.path.join(experiment_dir, 'pr_auc.png'))
    finally:
        # Release the figure even if saving fails; otherwise repeated calls
        # accumulate open figures.
        plt.close(fig)
    return round(area, 3)
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0."""
    if denominator == 0:
        return float('nan')
    return numerator / denominator


def calc_metrics(ground_truth, inferences, normalize=None, threshold=0.5):
    """Compute confusion-matrix based metrics at a decision threshold.

    Args:
        ground_truth: True labels, passed to
            ``sklearn.metrics.confusion_matrix`` as ``y_true``. Assumed to
            be binary; TODO confirm the label encoding against callers.
        inferences: Predicted scores. Must support ``inferences > threshold``
            element-wise (e.g. a NumPy array).
        normalize: Forwarded to ``confusion_matrix`` (``None`` keeps raw
            counts).
        threshold: Scores strictly greater than this value are predicted
            positive.

    Returns:
        A dict with keys ``'threshold'``, ``'TPR'``, ``'TNR'``, ``'PPV'``,
        ``'NPV'`` and ``'ACC'``, each rounded to three decimals. A ratio
        whose denominator is zero (for example PPV when nothing is predicted
        positive) is reported as NaN instead of raising
        ``ZeroDivisionError``.

    Raises:
        ValueError: If the confusion matrix is not 2x2, which happens when
            the labels and predictions do not span exactly two classes.
    """
    y_pred = inferences > threshold
    conf_mat = confusion_matrix(y_true=ground_truth,
                                y_pred=y_pred,
                                normalize=normalize)

    # tolist() converts the NumPy scalars to plain Python numbers.
    cells = conf_mat.ravel().tolist()
    if len(cells) != 4:
        raise ValueError(
            'calc_metrics expects a 2x2 confusion matrix, got shape %s; '
            'ground truth and predictions must span exactly two classes'
            % (conf_mat.shape,))
    tn, fp, fn, tp = cells

    metrics = {'threshold': threshold,
               'TPR': _safe_ratio(tp, tp + fn),
               'TNR': _safe_ratio(tn, tn + fp),
               'PPV': _safe_ratio(tp, tp + fp),
               'NPV': _safe_ratio(tn, tn + fn),
               'ACC': _safe_ratio(tp + tn, tp + tn + fp + fn)}
    # round() leaves NaN untouched, so undefined metrics stay NaN.
    return {k: round(v, 3) for k, v in metrics.items()}