--- /dev/null
+++ b/eval/classification/binary.py
@@ -0,0 +1,109 @@
+"""
+Binary classification metrics. Code adapted from aicc-ognet-global/eval/metrics.py.
+"""
+
+import torch
+import numpy as np
+from sklearn.metrics import (
+    accuracy_score,
+    precision_score,
+    recall_score,
+    f1_score,
+    precision_recall_curve,
+    average_precision_score,
+    roc_auc_score
+)
+
+def get_optimal_f1(groundtruth, probabilities,
+                   return_threshold=False):
+    """Get threshold maximizing f1 score."""
+    prec, rec, threshold = precision_recall_curve(
+        groundtruth, probabilities
+    )
+
+    # F1 at each threshold; points where prec + rec == 0 give 0/0 -> NaN,
+    # which the nan-aware argmax/max below ignore.
+    with np.errstate(divide="ignore", invalid="ignore"):
+        f1_values = 2 * (prec * rec) / (prec + rec)
+
+    argmax_f1 = np.nanargmax(f1_values)
+    max_f1 = np.nanmax(f1_values)
+
+    if return_threshold:
+        return max_f1, threshold[argmax_f1]
+    else:
+        return max_f1
+    
+
+def get_max_precision_above_recall(groundtruth, probabilities, value,
+                                   return_threshold=False):
+    """Get maximum precision such that recall >= value."""
+    if value > 1:
+        raise ValueError(f"Cannot attain a recall of {value}")
+    prec, rec, threshold = precision_recall_curve(
+        groundtruth, probabilities
+    )
+    
+    # Only consider operating points whose recall meets the constraint.
+    valid = np.where(rec >= value)[0]
+    index = valid[np.argmax(prec[valid])]
+    max_prec_above_rec = prec[index]
+
+    if return_threshold:
+        # prec[i] and rec[i] correspond to threshold[i]; the final
+        # (precision=1, recall=0) point has no associated threshold.
+        return max_prec_above_rec, threshold[min(index, len(threshold) - 1)]
+    else:
+        return max_prec_above_rec
+
+
+def get_binary_metrics(probs, labels, threshold=None):
+    """Compute binary metrics; defaults to the F1-optimal threshold."""
+    # Accept either torch tensors or numpy arrays.
+    if isinstance(labels, torch.Tensor):
+        labels = labels.detach().cpu().numpy()
+
+    if isinstance(probs, torch.Tensor):
+        probs = probs.detach().cpu().numpy()
+
+    if threshold is None:
+        _, threshold = get_optimal_f1(labels, probs, return_threshold=True)
+    # sklearn thresholds mean "score >= threshold", so predict with >= to
+    # reproduce the operating point the threshold was selected at.
+    preds = (probs >= threshold).astype(int)
+    acc = accuracy_score(labels, preds)
+    prec = precision_score(labels, preds)
+    rec = recall_score(labels, preds)
+    f1 = f1_score(labels, preds)
+    try:
+        auroc = roc_auc_score(labels, probs)
+    except ValueError:
+        # roc_auc_score raises ValueError when only one class is present
+        auroc = 0
+    auprc = average_precision_score(labels, probs)
+    prevalence = np.mean(labels)
+
+    return {
+        'threshold': threshold,
+        'prevalence': prevalence,
+        'f1': f1,
+        'accuracy': acc,
+        'precision': prec,
+        'recall': rec,
+        'auroc': auroc,
+        'auprc': auprc,
+    }
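+
+
+if __name__ == "__main__":
+    # Minimal smoke test on synthetic data (illustrative only; the synthetic
+    # labels/probs and the 0.8 recall floor are assumptions, not project config).
+    rng = np.random.default_rng(0)
+    labels = rng.integers(0, 2, size=200)
+    probs = np.clip(labels * 0.6 + rng.normal(0.2, 0.25, size=200), 0, 1)
+
+    max_f1, thr = get_optimal_f1(labels, probs, return_threshold=True)
+    prec_at_rec = get_max_precision_above_recall(labels, probs, 0.8)
+    print(f"optimal F1 = {max_f1:.3f} at threshold {thr:.3f}")
+    print(f"max precision with recall >= 0.8: {prec_at_rec:.3f}")
+    print(get_binary_metrics(probs, labels))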