[03464c]: / util / metrics.py

Download this file

50 lines (39 with data), 1.7 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
Survival-analysis evaluation metrics: concordance index (C-index) and integrated Brier score (IBS).
"""
import numpy as np
# from lifelines.utils import concordance_index
# from pysurvival.utils._metrics import _concordance_index
from sksurv.metrics import concordance_index_censored
from sksurv.metrics import integrated_brier_score
def c_index(true_T, true_E, pred_risk, include_ties=True):
    """
    Calculate the concordance index (C-index) for the survival prediction
    downstream task.

    Parameters
    ----------
    true_T : array-like of shape (n_samples,)
        Observed time (event or censoring time) of each sample.
    true_E : array-like of shape (n_samples,)
        Event indicators; cast to bool, so any nonzero value marks an
        observed event and zero marks a censored sample.
    pred_risk : array-like of shape (n_samples,)
        Predicted risk scores, passed as the ``estimate`` argument of
        ``concordance_index_censored`` (higher risk is expected for samples
        whose event occurs earlier).
    include_ties : bool, default True
        If True, comparable pairs with tied predicted risk count as half
        concordant; this is the C-index returned by
        ``concordance_index_censored``. If False, those pairs are excluded
        and the result is ``concordant / (concordant + discordant)``.

    Returns
    -------
    float
        The C-index.

    Raises
    ------
    ValueError
        If ``include_ties`` is False and no comparable pair has distinct
        predicted risks, so the tie-free C-index is undefined.
    """
    true_T = np.asarray(true_T)
    true_E = np.asarray(true_E).astype(bool)
    pred_risk = np.asarray(pred_risk)

    # The C-index depends only on pairwise comparisons, so the order of the
    # samples does not matter and no pre-sorting by time is needed.
    cindex, concordant, discordant, _tied_risk, _tied_time = (
        concordance_index_censored(true_E, true_T, pred_risk)
    )
    if include_ties:
        return cindex

    # Drop risk-tied pairs from both numerator and denominator.
    scored_pairs = concordant + discordant
    if scored_pairs == 0:
        raise ValueError(
            "No comparable pairs with distinct predicted risk; the C-index "
            "without ties is undefined"
        )
    return concordant / scored_pairs
def ibs(true_T, true_E, pred_survival, time_points):
    """
    Calculate the integrated Brier score (IBS) for the survival prediction
    downstream task.

    The same samples are passed to ``integrated_brier_score`` as both the
    training data (used to estimate the censoring distribution) and the test
    data.

    Parameters
    ----------
    true_T : array-like of shape (n_samples,)
        Observed time (event or censoring time) of each sample.
    true_E : array-like of shape (n_samples,)
        Event indicators; cast to bool, so any nonzero value marks an
        observed event.
    pred_survival : array-like of shape (n_samples, n_time_points)
        Predicted survival probabilities; column ``j`` corresponds to
        ``time_points[j]``.
    time_points : array-like of shape (n_time_points,)
        Time points at which ``pred_survival`` was evaluated. Points outside
        the follow-up interval ``[min(true_T), max(true_T))`` are dropped
        together with their columns.

    Returns
    -------
    float
        The integrated Brier score over the retained time points.

    Raises
    ------
    ValueError
        If fewer than two time points fall inside the follow-up interval.
    """
    true_T = np.asarray(true_T)
    pred_survival = np.asarray(pred_survival)
    time_points = np.asarray(time_points)

    # Structured (event, time) array in the layout sksurv expects. Times are
    # kept in float64 so that the validity filter below and sksurv's own
    # range check see exactly the same values (float32 rounding could push a
    # point that is just below max(T) onto max(T)).
    survival = np.empty(len(true_T), dtype=[('event', np.bool_), ('time', np.float64)])
    survival['event'] = np.asarray(true_E).astype(bool)
    survival['time'] = true_T

    # sksurv requires every time point to lie in the half-open interval
    # [min(T), max(T)): a point equal to max(T) is rejected, so the upper
    # bound must be exclusive.
    follow_up = survival['time']
    lower, upper = follow_up.min(), follow_up.max()
    valid = (time_points >= lower) & (time_points < upper)
    if np.count_nonzero(valid) < 2:
        raise ValueError(
            "At least two time points must lie within the follow-up time "
            f"[{lower}, {upper}) to compute the integrated Brier score"
        )

    kept_times = time_points[valid]
    kept_survival = pred_survival[:, valid]
    # sksurv sorts the time points but not the columns of the estimate, so
    # sort both together to keep each column paired with its time point.
    order = np.argsort(kept_times, kind='stable')
    return integrated_brier_score(
        survival, survival, kept_survival[:, order], kept_times[order]
    )