Switch to unified view

a b/AICare-baselines/metrics/es.py
1
import numpy as np
2
3
4
def calculate_confusion_matrix_value_result(outcome_pred, outcome_true):
5
    outcome_pred = 1 if outcome_pred > 0.5 else 0
6
    if outcome_pred == 1 and outcome_true == 1:
7
        return "tp"
8
    elif outcome_pred == 0 and outcome_true == 0:
9
        return "tn"
10
    elif outcome_pred == 1 and outcome_true == 0:
11
        return "fp"
12
    elif outcome_pred == 0 and outcome_true == 1:
13
        return "fn"
14
    else:
15
        raise ValueError("Unknown value occurred")
16
17
def calculate_es(los_true, threshold, case="tp"):
18
    metric = 0.0
19
    if case == "tp":
20
        if los_true >= threshold:  # predict correct in early stage
21
            metric = 1
22
        else:
23
            metric = los_true / threshold
24
    elif case == "fn":
25
        if los_true >= threshold:  # predict wrong in early stage
26
            metric = 0
27
        else:
28
            metric = los_true / threshold - 1
29
    elif case == "tn":
30
        metric = 0.0
31
    elif case == "fp":
32
        metric = -0.1 # penalty term
33
    return metric
34
35
36
def es_score(
37
    y_true_outcome,
38
    y_true_los,
39
    y_pred_outcome,
40
    threshold,
41
    verbose=0
42
):
43
    """
44
    Args:
45
        - threshold: 50%*mean_los (patient-wise) 
46
47
    Note:
48
        - y/predictions are already flattened here
49
        - so we don't need to consider visits_length
50
    """
51
    metric = []
52
    metric_optimal = []
53
    num_records = len(y_pred_outcome)
54
    for i in range(num_records):
55
        cur_outcome_pred = y_pred_outcome[i]
56
        cur_outcome_true = y_true_outcome[i]
57
        cur_los_true = y_true_los[i]
58
        prediction_result = calculate_confusion_matrix_value_result(cur_outcome_pred, cur_outcome_true)
59
        prediction_result_optimal = calculate_confusion_matrix_value_result(cur_outcome_true, cur_outcome_true)
60
        metric.append(
61
            calculate_es(
62
                cur_los_true,
63
                threshold,
64
                case=prediction_result,
65
            )
66
        )
67
        metric_optimal.append(
68
            calculate_es(
69
                cur_los_true,
70
                threshold,
71
                case=prediction_result_optimal,
72
            )
73
        )
74
    metric = np.array(metric)
75
    metric_optimal = np.array(metric_optimal)
76
    result = 0.0
77
    if metric_optimal.sum() > 0.0:
78
        result = metric.sum() / metric_optimal.sum()
79
    result = max(result, -1.0)
80
    if verbose:
81
        print("ES Score:", result)
82
    if isinstance(result, np.float64):
83
        result = result.item()
84
    return {"es": result}
85
86
if __name__ == "__main__":
87
    y_true_outcome = np.array([0,1])
88
    y_true_los = np.array([5,5])
89
    y_pred_outcome = np.array([0.7,0.7])
90
    y_pred_los = np.array([10,10])
91
    large_los = 110
92
    threshold = 10
93
    print(es_score(
94
        y_true_outcome,
95
        y_true_los,
96
        y_pred_outcome,
97
        threshold,
98
        verbose=0
99
    ))