|
a |
|
b/util/metrics.py |
|
|
1 |
""" |
|
|
2 |
Evaluation metrics for the survival prediction downstream task (C-index and integrated Brier score). |
|
|
3 |
""" |
|
|
4 |
import numpy as np |
|
|
5 |
# from lifelines.utils import concordance_index |
|
|
6 |
# from pysurvival.utils._metrics import _concordance_index |
|
|
7 |
from sksurv.metrics import concordance_index_censored |
|
|
8 |
from sksurv.metrics import integrated_brier_score |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
def c_index(true_T, true_E, pred_risk, include_ties=True):
    """
    Calculate the concordance index (C-index) for the survival prediction downstream task.

    Parameters
    ----------
    true_T : array-like of shape (n_samples,)
        Observed times (event or censoring time) for each sample.
    true_E : array-like of shape (n_samples,)
        Event indicators; values that are True after ``astype(bool)`` mark an
        observed event, the rest are treated as censored.
    pred_risk : array-like of shape (n_samples,)
        Predicted risk scores (a higher risk is expected to go with a shorter time).
    include_ties : bool, default True
        Kept only for backward compatibility with an earlier backend. It is
        currently ignored: ``concordance_index_censored`` always counts pairs
        with tied predicted risk as half-concordant.

    Returns
    -------
    float
        The C-index, i.e. the first element returned by
        ``sksurv.metrics.concordance_index_censored``.
    """
    # Convert to plain ndarrays so indexing is positional and the three inputs stay
    # aligned. A pandas Series with a non-default index would otherwise be indexed
    # by label while an ndarray is indexed by position. This also accepts lists.
    true_T = np.asarray(true_T)
    true_E = np.asarray(true_E).astype(bool)
    pred_risk = np.asarray(pred_risk)

    # No reordering is required: the C-index is a statistic over sample pairs, and
    # concordance_index_censored orders the samples by time internally.
    return concordance_index_censored(true_E, true_T, pred_risk)[0]
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def ibs(true_T, true_E, pred_survival, time_points):
    """
    Calculate the integrated Brier score (IBS) for the survival prediction downstream task.

    The same samples are used both to estimate the censoring distribution and to
    evaluate the score. They are passed as both ``survival_train`` and
    ``survival_test`` to ``sksurv.metrics.integrated_brier_score``.

    Parameters
    ----------
    true_T : array-like of shape (n_samples,)
        Observed times (event or censoring time) for each sample.
    true_E : array-like of shape (n_samples,)
        Event indicators (converted with ``astype(bool)``).
    pred_survival : array-like of shape (n_samples, n_time_points)
        Predicted survival probabilities; column ``j`` corresponds to ``time_points[j]``.
    time_points : array-like of shape (n_time_points,)
        Evaluation times. Points outside ``[min(true_T), max(true_T))`` are dropped,
        together with their prediction columns, because sksurv rejects them.

    Returns
    -------
    float
        The integrated Brier score over the retained time points.

    Raises
    ------
    ValueError
        Propagated from sksurv, e.g. when fewer than two valid time points remain.
    """
    true_T = np.asarray(true_T, dtype=np.float64)
    true_E = np.asarray(true_E).astype(bool)
    pred_survival = np.asarray(pred_survival)
    time_points = np.asarray(time_points, dtype=np.float64)

    # Structured (event, time) array as expected by sksurv. The time field is
    # float64 so that the range check below uses the same precision as the check
    # inside sksurv. With float32, rounding could push a kept point onto max(T).
    true = np.empty(len(true_T), dtype=[('event', np.bool_), ('time', np.float64)])
    true['event'] = true_E
    true['time'] = true_T

    # sksurv requires every time point to lie in [min(T), max(T)). The upper bound
    # is exclusive, so a point equal to the largest observed time must be dropped too.
    valid = (time_points >= true_T.min()) & (time_points < true_T.max())
    time_points = time_points[valid]
    pred_survival = pred_survival[:, valid]

    # sksurv passes the time points through np.unique (which sorts them) but does
    # not reorder the prediction columns. Sort both together to keep them aligned.
    order = np.argsort(time_points, kind='stable')
    time_points = time_points[order]
    pred_survival = pred_survival[:, order]

    return integrated_brier_score(true, true, pred_survival, time_points)