--- a +++ b/biobert_re/metrics.py @@ -0,0 +1,30 @@ +try: + from sklearn.metrics import f1_score + + _has_sklearn = True +except (AttributeError, ImportError): + _has_sklearn = False + + +def is_sklearn_available(): + return _has_sklearn + +if _has_sklearn: + + def simple_accuracy(preds, labels): + return (preds == labels).mean() + + def acc_and_f1(preds, labels): + acc = simple_accuracy(preds, labels) + f1 = f1_score(y_true=labels, y_pred=preds) + return { + "acc": acc, + "f1": f1, + "acc_and_f1": (acc + f1) / 2, + } + + + def glue_compute_metrics(preds, labels): + assert len(preds) == len(labels) + return acc_and_f1(preds, labels) +