|
a |
|
b/metrics/__init__.py |
|
|
1 |
import torch |
|
|
2 |
|
|
|
3 |
from .binary_classification_metrics import get_binary_metrics |
|
|
4 |
from .regression_metrics import get_regression_metrics |
|
|
5 |
from .metrics_utils import check_metric_is_better |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
def reverse_los(y, los_info):
    """Map normalised LOS values back to their original scale.

    Inverts a z-score style normalisation: ``y * los_std + los_mean``.
    ("LOS" is presumably length of stay -- inferred from naming only.)

    Args:
        y: Normalised value(s); anything supporting ``*`` and ``+`` with the
            statistics below (scalar, NumPy array, ``torch.Tensor``, ...).
        los_info: Mapping providing ``"los_std"`` and ``"los_mean"``.

    Returns:
        The de-normalised value(s), same kind as ``y``.
    """
    scale = los_info["los_std"]
    offset = los_info["los_mean"]
    return y * scale + offset
|
|
10 |
|
|
|
11 |
def get_all_metrics(preds, labels, task, los_info):
    """Compute evaluation metrics for the given prediction task.

    Args:
        preds: Model predictions, as a ``torch.Tensor`` or anything
            ``torch.tensor`` accepts (e.g. a NumPy array or nested list).
        labels: Ground-truth labels; same accepted types as ``preds``.
            The ``"los"`` and ``"multitask"`` branches index ``labels[:, 1]``
            (and ``"multitask"`` also ``labels[:, 0]``), so labels must be at
            least 2-D there. Presumably column 0 is the outcome label and
            column 1 the normalised LOS target -- TODO confirm against the
            data loader.
        task: One of ``"outcome"``, ``"los"`` or ``"multitask"``.
        los_info: Mapping with ``"los_mean"`` and ``"los_std"``, used by
            ``reverse_los`` to undo the LOS normalisation before the
            regression metrics are computed.

    Returns:
        ``"outcome"``: the result of ``get_binary_metrics(preds, labels)``.
        ``"los"``: the result of ``get_regression_metrics`` on de-normalised
        predictions and targets.
        ``"multitask"``: the binary and regression results combined with
        ``|`` (presumably both are dicts, in which case regression keys win
        on collision -- confirm in the metric modules).

    Raises:
        ValueError: If ``task`` is not one of the supported values.
    """
    # Accept NumPy arrays / Python lists as well as tensors; inputs that are
    # already tensors are passed through untouched.
    if not isinstance(preds, torch.Tensor):
        preds = torch.tensor(preds)
    if not isinstance(labels, torch.Tensor):
        labels = torch.tensor(labels)

    if task == "outcome":
        return get_binary_metrics(preds, labels)

    if task == "los":
        # NOTE(review): preds is used whole while labels are sliced to
        # column 1; presumably preds is already 1-D for this task -- confirm.
        return get_regression_metrics(
            reverse_los(preds, los_info),
            reverse_los(labels[:, 1], los_info),
        )

    if task == "multitask":
        # Column 0: binary outcome; column 1: normalised LOS (see docstring).
        binary_metrics = get_binary_metrics(preds[:, 0], labels[:, 0])
        regression_metrics = get_regression_metrics(
            reverse_los(preds[:, 1], los_info),
            reverse_los(labels[:, 1], los_info),
        )
        return binary_metrics | regression_metrics

    raise ValueError("Task not supported")
|
|
26 |
|