[1de6ed]: / biobert_re / metrics.py

Download this file

31 lines (21 with data), 664 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# scikit-learn is an optional dependency: attempt the import once at module
# load time and record the outcome in a private flag that the rest of the
# module consults.
try:
    from sklearn.metrics import f1_score
except (AttributeError, ImportError):
    # The import failed with one of the two tolerated error types; remember
    # that f1_score is unavailable instead of propagating the error.
    _has_sklearn = False
else:
    _has_sklearn = True
def is_sklearn_available() -> bool:
    """Return the module-level ``_has_sklearn`` flag.

    The flag is set once, when this module is imported, by the guarded
    ``from sklearn.metrics import f1_score`` statement.
    """
    return _has_sklearn
if _has_sklearn:
    # The metric helpers below are only defined when the guarded
    # scikit-learn import succeeded, because acc_and_f1 calls f1_score.

    def simple_accuracy(preds, labels):
        """Return the mean of the element-wise comparison ``preds == labels``.

        Both arguments must support an ``==`` whose result exposes a
        ``.mean()`` method (presumably NumPy arrays -- confirm with callers;
        plain Python lists would not work here).
        """
        return (preds == labels).mean()

    def acc_and_f1(preds, labels):
        """Compute accuracy, F1 and the arithmetic mean of the two.

        F1 is obtained from ``sklearn.metrics.f1_score`` called with its
        default settings, using ``labels`` as ``y_true`` and ``preds`` as
        ``y_pred``.

        Returns:
            A dict with the keys ``"acc"``, ``"f1"`` and ``"acc_and_f1"``.
        """
        acc = simple_accuracy(preds, labels)
        f1 = f1_score(y_true=labels, y_pred=preds)
        return {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
        }

    def glue_compute_metrics(preds, labels):
        """Check that both inputs have equal length, then score them.

        Returns:
            The dict produced by ``acc_and_f1(preds, labels)``.

        Raises:
            AssertionError: If ``len(preds) != len(labels)``.
        """
        if len(preds) != len(labels):
            # Raised explicitly rather than through an ``assert`` statement so
            # the check is not stripped when Python runs with ``-O``. The
            # exception type stays AssertionError so callers that already
            # catch it keep working.
            raise AssertionError(
                "Predictions and labels have mismatched lengths "
                "{} and {}".format(len(preds), len(labels))
            )
        return acc_and_f1(preds, labels)