# metrics/binary_classification_metrics.py

import numpy as np
import torch
from sklearn import metrics as sklearn_metrics
from torchmetrics import AUROC, Accuracy, AveragePrecision
from torchmetrics.classification import BinaryF1Score


def minpse(preds, labels):
    """Minimum of precision and sensitivity: the largest min(precision, recall)
    attained at any threshold on the precision-recall curve."""
    precisions, recalls, _ = sklearn_metrics.precision_recall_curve(labels, preds)
    return float(np.max(np.minimum(precisions, recalls)))
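

# Toy example (hypothetical numbers): for labels [0, 1, 0, 1] and predicted
# probabilities [0.2, 0.6, 0.7, 0.4], the precision-recall curve reaches
# precision = 2/3 at recall = 1.0 (threshold 0.4), which maximizes
# min(precision, recall), so minpse returns 2/3 ~= 0.667.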


def get_binary_metrics(preds, labels):
    accuracy = Accuracy(task="binary", threshold=0.5)
    auroc = AUROC(task="binary")
    auprc = AveragePrecision(task="binary")
    f1 = BinaryF1Score()

    # torchmetrics expects integer targets for binary classification
    labels = labels.type(torch.int)

    # a single call updates each metric's internal state with the batch
    accuracy(preds, labels)
    auroc(preds, labels)
    auprc(preds, labels)
    f1(preds, labels)

    # sklearn works on numpy arrays; detach and move to CPU first in case
    # the tensors came straight from a GPU model
    minpse_score = minpse(preds.detach().cpu().numpy(), labels.detach().cpu().numpy())

    return {
        "accuracy": accuracy.compute().item(),
        "auroc": auroc.compute().item(),
        "auprc": auprc.compute().item(),
        "f1": f1.compute().item(),
        "minpse": minpse_score,
    }
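

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # random probabilities and labels, just to show the expected inputs.
    torch.manual_seed(0)
    preds = torch.rand(100)                    # predicted probabilities in [0, 1]
    labels = (torch.rand(100) > 0.5).float()   # binary ground-truth targets
    print(get_binary_metrics(preds, labels))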