Diff of /src/metrics.py [000000] .. [f45789]

--- /dev/null
+++ b/src/metrics.py
@@ -0,0 +1,74 @@
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.metrics import (precision_recall_curve as pr_curve,
+                             confusion_matrix,
+                             roc_curve,
+                             auc)
+
+def roc_auc(ground_truth, inferences, experiment_dir):
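+    """Plot the ROC curve for `inferences` against binary `ground_truth`,
+    save it as roc_auc.png under `experiment_dir`, and return the AUC
+    rounded to three decimals."""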
+    fpr, tpr, _ = roc_curve(ground_truth, inferences)
+    auc_value = float(auc(fpr, tpr))
+    plt.title('Receiver Operating Characteristic')
+    plt.plot(fpr, tpr, 'b', label='AUC = %0.3f' % auc_value)
+    plt.legend(loc='lower right')
+    plt.plot([0, 1], [0, 1], 'r--')
+    plt.xlim([0, 1])
+    plt.ylim([0, 1])
+    plt.ylabel('True Positive Rate')
+    plt.xlabel('False Positive Rate')
+    plt.savefig(os.path.join(experiment_dir, 'roc_auc.png'))
+    plt.clf()
+    return round(auc_value, 3)
+
+
+def pr_auc(ground_truth, inferences, experiment_dir):
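+    """Plot the precision-recall curve, save it as pr_auc.png under
+    `experiment_dir`, and return its AUC rounded to three decimals.
+    sklearn's auc() accepts the monotonically decreasing recall values
+    that pr_curve returns."""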
+    precision, recall, _ = pr_curve(ground_truth, inferences)
+    auc_value = float(auc(recall, precision))
+    plt.title('Precision-Recall')
+    plt.plot(recall, precision, 'b', label='AUC = %0.3f' % auc_value)
+    plt.legend(loc='lower right')
+    plt.xlim([0, 1])
+    plt.ylim([0, 1])
+    plt.ylabel('Precision')
+    plt.xlabel('Recall')
+    plt.savefig(os.path.join(experiment_dir, 'pr_auc.png'))
+    plt.clf()
+    return round(auc_value, 3)
+
+
+def calc_metrics(ground_truth, inferences, normalize=None, threshold=0.5):
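+    """Binarize `inferences` at `threshold` and compute confusion-matrix
+    rates. Leave `normalize` at None: a normalized matrix can distort the
+    count-based rates derived below."""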
+    y_pred = np.asarray(inferences) > threshold
+
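+    # For binary labels, sklearn's confusion_matrix is laid out as
+    # [[tn, fp], [fn, tp]], so ravel() yields tn, fp, fn, tp in that order.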
+    conf_mat = confusion_matrix(y_true=ground_truth,
+                                y_pred=y_pred,
+                                normalize=normalize)
+    tn, fp, fn, tp = conf_mat.ravel().tolist()
+
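+    # Derived rates: TPR (sensitivity/recall), TNR (specificity),
+    # PPV (precision), NPV (negative predictive value), ACC (accuracy).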
+    tpr = tp / (tp + fn)
+    tnr = tn / (tn + fp)
+    ppv = tp / (tp + fp)
+    npv = tn / (tn + fn)
+    acc = (tp + tn) / (tp + tn + fp + fn)
+
+    metrics = {'threshold': threshold,
+               'TPR': tpr,
+               'TNR': tnr,
+               'PPV': ppv,
+               'NPV': npv,
+               'ACC': acc}
+    return {k: round(v, 3) for k, v in metrics.items()}
\ No newline at end of file