b/src/metrics.py

import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (precision_recall_curve as pr_curve,
                             confusion_matrix,
                             precision_score,
                             accuracy_score,
                             recall_score,
                             roc_curve,
                             auc)

def roc_auc(ground_truth, inferences, experiment_dir):
    """Plot the ROC curve, save it under experiment_dir, and return the AUC rounded to 3 dp."""
    fpr, tpr, _ = roc_curve(ground_truth, inferences)
    roc_auc = auc(fpr, tpr).item()
    plt.title('Receiver Operating Characteristic')
    plt.plot(fpr, tpr, 'b', label='AUC = %0.3f' % roc_auc)
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.savefig(os.path.join(experiment_dir, 'roc_auc.png'))
    plt.clf()
    return round(roc_auc, 3)


def pr_auc(ground_truth, inferences, experiment_dir):
    """Plot the precision-recall curve, save it under experiment_dir, and return the AUC rounded to 3 dp."""
    precision, recall, _ = pr_curve(ground_truth, inferences)
    # pr_curve returns recall in decreasing order; auc() accepts a monotonically
    # increasing or decreasing x, so no re-sorting is needed.
    pr_auc = auc(recall, precision).item()
    plt.title('Precision - Recall')
    plt.plot(recall, precision, 'b', label='AUC = %0.3f' % pr_auc)
    plt.legend(loc='lower right')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('Precision')
    plt.xlabel('Recall')
    plt.savefig(os.path.join(experiment_dir, 'pr_auc.png'))
    plt.clf()
    return round(pr_auc, 3)


def calc_metrics(ground_truth, inferences, normalize=None, threshold=0.5):
    """Binarise the inferences at `threshold` and return confusion-matrix metrics, each rounded to 3 dp."""
    y_pred = inferences > threshold
    #accuracy = accuracy_score(y_true=ground_truth, y_pred=y_pred).item()
    #recall = recall_score(y_true=ground_truth, y_pred=y_pred).item()
    #precision = precision_score(y_true=ground_truth, y_pred=y_pred).item()

    conf_mat = confusion_matrix(y_true=ground_truth,
                                y_pred=y_pred,
                                normalize=normalize)
    tn, fp, fn, tp = [x.item() for x in conf_mat.ravel()]

    tpr = tp / (tp + fn)                    # sensitivity / recall
    tnr = tn / (tn + fp)                    # specificity
    ppv = tp / (tp + fp)                    # precision
    npv = tn / (tn + fn)                    # negative predictive value
    acc = (tp + tn) / (tp + tn + fp + fn)

    metrics = {'threshold': threshold,
               'TPR': tpr,
               'TNR': tnr,
               'PPV': ppv,
               'NPV': npv,
               'ACC': acc}
    return {k: round(v, 3) for k, v in metrics.items()}
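

# A minimal usage sketch, not part of the original module: it assumes 0/1 ground-truth
# labels and probability-like scores held in NumPy arrays, and that the directory the
# plots are saved into already exists. The data below is synthetic and illustrative only.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    labels = rng.integers(0, 2, size=200)                             # hypothetical ground truth
    scores = np.clip(labels * 0.6 + rng.normal(0.2, 0.3, 200), 0, 1)  # hypothetical model scores
    out_dir = '.'                                                     # hypothetical experiment dir
    print('ROC AUC:', roc_auc(labels, scores, out_dir))
    print('PR AUC:', pr_auc(labels, scores, out_dir))
    print(calc_metrics(labels, scores, threshold=0.5))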