[0f1df3]: / AICare-baselines / metrics / binary_classification_metrics.py

Download this file

35 lines (29 with data), 1.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
from torchmetrics import AUROC, Accuracy, AveragePrecision, Precision, Recall
from torchmetrics.classification import BinaryF1Score, ConfusionMatrix
import numpy as np
from sklearn import metrics as sklearn_metrics
def minpse(preds, labels):
    """Return the MinPSE score: the best ``min(precision, recall)`` on the PR curve.

    Args:
        preds: Predicted scores for the positive class (torch tensor, NumPy
            array, or sequence), aligned element-wise with ``labels``.
        labels: Ground-truth binary labels (torch tensor, NumPy array, or
            sequence).

    Returns:
        A Python ``float``. It is the maximum, over every point of the curve
        returned by ``sklearn.metrics.precision_recall_curve``, of
        ``min(precision, recall)``.
    """
    # sklearn converts inputs with np.asarray, which fails for tensors that
    # require grad or live on a GPU; hand it plain CPU NumPy arrays instead.
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    precisions, recalls, _ = sklearn_metrics.precision_recall_curve(labels, preds)
    # Element-wise min of the two curves, then the best operating point.
    # float() keeps the return type consistent with the other metrics (.item()).
    return float(np.max(np.minimum(precisions, recalls)))
def get_binary_metrics(preds, labels):
    """Compute binary-classification metrics for a set of predictions.

    Args:
        preds: Tensor of predicted scores/probabilities for the positive class.
            It is passed unchanged to the torchmetrics binary metrics.
            Accuracy uses a 0.5 threshold.
        labels: Tensor of ground-truth labels, cast to ``torch.int`` before use.
            Presumably on the same device as ``preds`` — TODO confirm with callers.

    Returns:
        dict with keys ``"accuracy"``, ``"auroc"``, ``"auprc"``, ``"f1"`` and
        ``"minpse"``, each mapped to a Python float.
    """
    labels = labels.type(torch.int)
    metrics = {
        "accuracy": Accuracy(task="binary", threshold=0.5),
        "auroc": AUROC(task="binary"),
        "auprc": AveragePrecision(task="binary"),
        "f1": BinaryF1Score(),
    }
    results = {}
    for name, metric in metrics.items():
        # Keep metric states on the inputs' device so CUDA inputs work.
        metric = metric.to(preds.device)
        # update() only accumulates state; calling the metric (forward) would
        # also compute a per-batch value that is immediately discarded.
        metric.update(preds, labels)
        results[name] = metric.compute().item()
    # minpse goes through sklearn/NumPy, so give it detached CPU tensors.
    results["minpse"] = float(minpse(preds.detach().cpu(), labels.cpu()))
    return results