Switch to unified view

a b/metrics/binary_classification_metrics.py
1
import torch
2
from torchmetrics import AUROC, Accuracy, AveragePrecision, Precision, Recall
3
from torchmetrics.classification import BinaryF1Score, ConfusionMatrix
4
import numpy as np
5
from sklearn import metrics as sklearn_metrics
6
7
def minpse(preds, labels):
8
    precisions, recalls, thresholds = sklearn_metrics.precision_recall_curve(labels, preds)
9
    minpse_score = np.max([min(x, y) for (x, y) in zip(precisions, recalls)])
10
    return minpse_score
11
12
def get_binary_metrics(preds, labels):
13
    accuracy = Accuracy(task="binary", threshold=0.5)
14
    auroc = AUROC(task="binary")
15
    auprc = AveragePrecision(task="binary")
16
    f1 = BinaryF1Score()
17
18
    # convert labels type to int
19
    labels = labels.type(torch.int)
20
    accuracy(preds, labels)
21
    auroc(preds, labels)
22
    auprc(preds, labels)
23
    f1(preds, labels)
24
25
    minpse_score = minpse(preds, labels) 
26
27
    # return a dictionary
28
    return {
29
        "accuracy": accuracy.compute().item(),
30
        "auroc": auroc.compute().item(),
31
        "auprc": auprc.compute().item(),
32
        "f1": f1.compute().item(),
33
        "minpse": minpse_score,
34
    }