import numpy as np
import torch
from sklearn import metrics as sklearn_metrics
from torchmetrics import AUROC, Accuracy, AveragePrecision
from torchmetrics.classification import BinaryF1Score


def minpse(preds, labels):
    """Return the min(Se, +P) score: the maximum over all decision
    thresholds of min(precision, recall).

    Commonly reported for clinical risk-prediction tasks (e.g. in-hospital
    mortality) as a threshold-free summary of the precision-recall curve.

    Args:
        preds: 1-D array-like of predicted positive-class probabilities.
            Torch tensors are moved to CPU and detached before being
            handed to scikit-learn, which cannot consume CUDA tensors.
        labels: 1-D array-like of binary ground-truth labels (0/1).

    Returns:
        float: max over thresholds of min(precision, recall).
    """
    # sklearn requires CPU numpy-compatible input; detach in case preds
    # still carry autograd history.
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    precisions, recalls, _thresholds = sklearn_metrics.precision_recall_curve(
        labels, preds
    )
    # precisions/recalls are aligned per threshold; take the best
    # worst-case of the two at any single operating point.
    return np.max([min(p, r) for p, r in zip(precisions, recalls)])


def get_binary_metrics(preds, labels):
    """Compute a standard panel of binary-classification metrics.

    Args:
        preds: 1-D torch tensor of predicted positive-class probabilities.
        labels: 1-D torch tensor of ground-truth labels; cast to int as
            required by the torchmetrics binary tasks.

    Returns:
        dict[str, float]: keys "accuracy", "auroc", "auprc", "f1",
        "minpse", each a plain Python float.
    """
    accuracy = Accuracy(task="binary", threshold=0.5)
    auroc = AUROC(task="binary")
    auprc = AveragePrecision(task="binary")
    f1 = BinaryF1Score()

    # torchmetrics binary tasks expect integer targets.
    labels = labels.type(torch.int)

    # Calling the metric object updates its internal state; the final
    # value is read back below via .compute().
    accuracy(preds, labels)
    auroc(preds, labels)
    auprc(preds, labels)
    f1(preds, labels)

    minpse_score = minpse(preds, labels)

    return {
        "accuracy": accuracy.compute().item(),
        "auroc": auroc.compute().item(),
        "auprc": auprc.compute().item(),
        "f1": f1.compute().item(),
        "minpse": minpse_score,
    }