--- a +++ b/util/metrics.py @@ -0,0 +1,49 @@ +""" +Contain some metrics +""" +import numpy as np +# from lifelines.utils import concordance_index +# from pysurvival.utils._metrics import _concordance_index +from sksurv.metrics import concordance_index_censored +from sksurv.metrics import integrated_brier_score + + +def c_index(true_T, true_E, pred_risk, include_ties=True): + """ + Calculate c-index for survival prediction downstream task + """ + # Ordering true_T, true_E and pred_score in descending order according to true_T + order = np.argsort(-true_T) + + true_T = true_T[order] + true_E = true_E[order] + pred_risk = pred_risk[order] + + # Calculating the c-index + # result = concordance_index(true_T, -pred_risk, true_E) + # result = _concordance_index(pred_risk, true_T, true_E, include_ties)[0] + result = concordance_index_censored(true_E.astype(bool), true_T, pred_risk)[0] + + return result + + +def ibs(true_T, true_E, pred_survival, time_points): + """ + Calculate integrated brier score for survival prediction downstream task + """ + true_E_bool = true_E.astype(bool) + true = np.array([(true_E_bool[i], true_T[i]) for i in range(len(true_E))], dtype=[('event', np.bool_), ('time', np.float32)]) + + # time points must be within the range of T + min_T = true_T.min() + max_T = true_T.max() + valid_index = [] + for i in range(len(time_points)): + if min_T <= time_points[i] <= max_T: + valid_index.append(i) + time_points = time_points[valid_index] + pred_survival = pred_survival[:, valid_index] + + result = integrated_brier_score(true, true, pred_survival, time_points) + + return result