biobert_re/metrics.py
|
|
try:
    from sklearn.metrics import f1_score

    _has_sklearn = True
except (AttributeError, ImportError):
    # scikit-learn is an optional dependency; the metric functions below
    # are only defined when it is installed.
    _has_sklearn = False


def is_sklearn_available():
    return _has_sklearn


if _has_sklearn:

    def simple_accuracy(preds, labels):
        # preds and labels are NumPy arrays of the same shape.
        return (preds == labels).mean()

    def acc_and_f1(preds, labels):
        # Accuracy, F1 (sklearn's default binary averaging), and their mean.
        acc = simple_accuracy(preds, labels)
        f1 = f1_score(y_true=labels, y_pred=preds)
        return {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
        }

    def glue_compute_metrics(preds, labels):
        assert len(preds) == len(labels)
        return acc_and_f1(preds, labels)
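A minimal usage sketch (not part of the file above), assuming binary relation labels encoded as 0/1 NumPy arrays; the import path mirrors the file location shown here and is an assumption about how the package is laid out:

import numpy as np

from biobert_re.metrics import glue_compute_metrics, is_sklearn_available

# Toy predictions/labels for a binary relation-extraction task.
preds = np.array([1, 0, 1, 1, 0])
labels = np.array([1, 0, 0, 1, 0])

if is_sklearn_available():
    metrics = glue_compute_metrics(preds, labels)
    print(metrics)  # acc 0.8, f1 0.8, acc_and_f1 0.8 for this toy input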