Diff of /metrics/__init__.py [000000] .. [0f1df3]

Switch to unified view

a b/metrics/__init__.py
1
import torch
2
3
from .binary_classification_metrics import get_binary_metrics
4
from .regression_metrics import get_regression_metrics
5
from .metrics_utils import check_metric_is_better
6
7
8
def reverse_los(y, los_info):
9
    return y * los_info["los_std"] + los_info["los_mean"]
10
11
def get_all_metrics(preds, labels, task, los_info):
12
    # convert preds and labels to tensor if they are ndarray type
13
    if isinstance(preds, torch.Tensor) == False:
14
        preds = torch.tensor(preds)
15
    if isinstance(labels, torch.Tensor) == False:
16
        labels = torch.tensor(labels)
17
    
18
    if task == "outcome":
19
        return get_binary_metrics(preds, labels)
20
    elif task == "los":
21
        return get_regression_metrics(reverse_los(preds, los_info), reverse_los(labels[:, 1], los_info))
22
    elif task == "multitask":
23
        return get_binary_metrics(preds[:, 0], labels[:, 0]) | get_regression_metrics(reverse_los(preds[:, 1], los_info), reverse_los(labels[:, 1], los_info))
24
    else:
25
        raise ValueError("Task not supported")
26