a b/util/metrics.py
1
"""
2
Evaluation metrics for the survival prediction downstream task (c-index and integrated Brier score).
3
"""
4
import numpy as np
5
# from lifelines.utils import concordance_index
6
# from pysurvival.utils._metrics import _concordance_index
7
from sksurv.metrics import concordance_index_censored
8
from sksurv.metrics import integrated_brier_score
9
10
11
def c_index(true_T, true_E, pred_risk, include_ties=True):
    """
    Calculate the concordance index (c-index) for the survival prediction downstream task.

    Thin wrapper around ``sksurv.metrics.concordance_index_censored``.

    Parameters
    ----------
    true_T : array-like, one value per sample
        Observed times (event or censoring time).
    true_E : array-like, one value per sample
        Event indicator; converted to ``bool`` (non-zero means the event was observed).
    pred_risk : array-like, one value per sample
        Predicted risk scores, passed to sksurv as the risk estimate
        (a higher score is treated as a higher risk of experiencing the event).
    include_ties : bool, default True
        If True, comparable pairs with tied predicted risk count as half-concordant
        (the value sksurv returns). If False, those pairs are excluded from both the
        numerator and the denominator.

    Returns
    -------
    float
        The c-index. ``nan`` when ``include_ties`` is False and every comparable pair
        has tied predicted risk (nothing left to compare).
    """
    true_T = np.asarray(true_T)
    true_E = np.asarray(true_E).astype(bool)
    pred_risk = np.asarray(pred_risk)

    # The c-index does not depend on sample order, so no pre-sorting is needed.
    cindex, concordant, discordant, tied_risk, _tied_time = concordance_index_censored(
        true_E, true_T, pred_risk
    )
    if include_ties:
        return cindex

    # Drop pairs whose predicted risks are tied (tied_risk) from the ratio entirely.
    n_untied = concordant + discordant
    if n_untied == 0:
        return float("nan")
    return concordant / n_untied
28
29
30
def ibs(true_T, true_E, pred_survival, time_points):
    """
    Calculate the integrated Brier score (IBS) for the survival prediction downstream task.

    Thin wrapper around ``sksurv.metrics.integrated_brier_score``. The same samples are
    passed as both the training and the test survival data, so the censoring distribution
    used for weighting is estimated from ``true_T``/``true_E`` themselves.

    Time points outside the follow-up range of ``true_T`` are dropped, together with the
    matching columns of ``pred_survival``, because sksurv rejects them.

    Parameters
    ----------
    true_T : array-like, one value per sample
        Observed times (event or censoring time); stored as float32.
    true_E : array-like, one value per sample
        Event indicator; converted to ``bool``.
    pred_survival : array-like, shape (n_samples, len(time_points))
        Predicted survival probabilities; column j corresponds to ``time_points[j]``.
    time_points : array-like
        Evaluation times. NOTE: sksurv sorts and de-duplicates the times (``np.unique``)
        before pairing them with columns, so these should already be sorted and unique.

    Returns
    -------
    float
        The integrated Brier score over the retained time points.

    Raises
    ------
    ValueError
        If fewer than two time points fall inside the follow-up range.
    """
    true_T = np.asarray(true_T)
    true_E = np.asarray(true_E)
    time_points = np.asarray(time_points)
    pred_survival = np.asarray(pred_survival)

    # Structured (event, time) record array in the layout sksurv expects.
    true = np.empty(len(true_E), dtype=[('event', np.bool_), ('time', np.float32)])
    true['event'] = true_E.astype(bool)
    true['time'] = true_T

    # sksurv casts the evaluation times to the dtype of the survival times (float32 here)
    # and only accepts the half-open range [min(time), max(time)); a time point equal to
    # the largest observed time raises ValueError. Filter with the same dtype and bounds.
    survival_times = true['time']
    lower, upper = survival_times.min(), survival_times.max()
    cast_points = time_points.astype(survival_times.dtype)
    valid = (cast_points >= lower) & (cast_points < upper)

    n_valid = int(np.count_nonzero(valid))
    if n_valid < 2:
        raise ValueError(
            f"ibs needs at least two time points within [{lower}, {upper}), got {n_valid}"
        )

    time_points = time_points[valid]
    pred_survival = pred_survival[:, valid]

    return integrated_brier_score(true, true, pred_survival, time_points)