Diff of /util/metrics.py [000000] .. [03464c]

Switch to side-by-side view

--- a
+++ b/util/metrics.py
@@ -0,0 +1,49 @@
+"""
+Contain some metrics
+"""
+import numpy as np
+# from lifelines.utils import concordance_index
+# from pysurvival.utils._metrics import _concordance_index
+from sksurv.metrics import concordance_index_censored
+from sksurv.metrics import integrated_brier_score
+
+
+def c_index(true_T, true_E, pred_risk, include_ties=True):
+    """
+    Calculate c-index for survival prediction downstream task
+    """
+    # Ordering true_T, true_E and pred_score in descending order according to true_T
+    order = np.argsort(-true_T)
+
+    true_T = true_T[order]
+    true_E = true_E[order]
+    pred_risk = pred_risk[order]
+
+    # Calculating the c-index
+    # result = concordance_index(true_T, -pred_risk, true_E)
+    # result = _concordance_index(pred_risk, true_T, true_E, include_ties)[0]
+    result = concordance_index_censored(true_E.astype(bool), true_T, pred_risk)[0]
+
+    return result
+
+
+def ibs(true_T, true_E, pred_survival, time_points):
+    """
+    Calculate integrated brier score for survival prediction downstream task
+    """
+    true_E_bool = true_E.astype(bool)
+    true = np.array([(true_E_bool[i], true_T[i]) for i in range(len(true_E))], dtype=[('event', np.bool_), ('time', np.float32)])
+
+    # time points must be within the range of T
+    min_T = true_T.min()
+    max_T = true_T.max()
+    valid_index = []
+    for i in range(len(time_points)):
+        if min_T <= time_points[i] <= max_T:
+            valid_index.append(i)
+    time_points = time_points[valid_index]
+    pred_survival = pred_survival[:, valid_index]
+
+    result = integrated_brier_score(true, true, pred_survival, time_points)
+
+    return result