|
a |
|
b/metrics.py |
|
|
1 |
import numpy as np |
|
|
2 |
import pandas as pd |
|
|
3 |
import warnings |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
def cindex(y_true_times, predicted_times, tol=1e-8):
    """
    Evaluate the concordance index from pandas DataFrames, taking ties into account.

    Author: Romuald Menuet & Rémy Dubois

    Args:
        y_true_times: pd.DataFrame
            Ground truth, indexed by `PatientID`, with exactly two columns:
            `Event` (cast to int before use) and `SurvivalTime`, the
            float-valued column of true survival times.
        predicted_times: pd.DataFrame
            Predictions, indexed by `PatientID`, with exactly two columns:
            `SurvivalTime`, the float-valued column of predicted survival
            times, and `Event`, whose value does not matter. It must be
            present so that target and predictions have the same format.
        tol: float
            Small value added to the denominator for numerical stability;
            forwarded to `_cindex_np`.

    Returns:
        Concordance index, as described here:
        https://square.github.io/pysurvival/metrics/c_index.html

    Raises:
        AssertionError: if the inputs are not DataFrames in the format above,
            do not describe the same uniquely-indexed patients, or if any
            predicted time is not strictly greater than 1.
    """
    true_fmt_msg = 'Y true times should be pd dataframe with `PatientID` as index, and `Event` and `SurvivalTime` as columns'

    assert isinstance(y_true_times, pd.DataFrame), true_fmt_msg
    assert isinstance(predicted_times, pd.DataFrame), 'Predicted times should be pd dataframe with patient `PatientID` as index, and `Event` and `SurvivalTime` as columns'
    assert len(y_true_times.shape) == 2, true_fmt_msg
    assert len(predicted_times.shape) == 2, 'Predicted times should be pd dataframe with `PatientID` as index, and `Event` and `SurvivalTime` as columns'
    assert set(y_true_times.columns) == {'Event', 'SurvivalTime'}, true_fmt_msg
    assert set(predicted_times.columns) == {'Event', 'SurvivalTime'}, 'Predicted times should be pd dataframe with `PatientID` as index, and `Event` and `SurvivalTime` as columns'
    np.testing.assert_equal(y_true_times.shape, predicted_times.shape, err_msg="Not same amount of predicted versus true samples")
    assert set(y_true_times.index) == set(predicted_times.index), 'Not same patients in prediction versus ground truth'
    # A duplicated index would make the `.loc` re-ordering below return extra
    # rows and fail later with an unrelated length error, so reject it here.
    assert y_true_times.index.is_unique and predicted_times.index.is_unique, 'Patient index should be unique in prediction and ground truth'
    assert np.all(predicted_times['SurvivalTime'] > 0), 'Predicted times should all be positive'

    events = y_true_times.Event
    y_true_times = y_true_times.SurvivalTime
    predicted_times = predicted_times.SurvivalTime

    # Align predictions and events on the ground-truth patient order.
    predicted_times = predicted_times.loc[y_true_times.index]
    events = events.loc[y_true_times.index]

    events = events.values.astype(int)
    y_true_times = y_true_times.values.astype(float)
    predicted_times = predicted_times.values.astype(float)

    # Values at or below 1 suggest a risk score was submitted instead of a
    # duration, which would invert the ranking.
    np.testing.assert_array_less(
        1.,
        predicted_times,
        err_msg=("Predicted times should all be above 1. "
                 "They should be in days. Make sure that you are not predicting risk instead of time."))

    # Forward `tol` so that a caller-supplied tolerance is actually honoured.
    return _cindex_np(y_true_times, predicted_times, events, tol=tol)
|
|
54 |
|
|
|
55 |
|
|
|
56 |
def _cindex_np(times, predicted_times, events, tol=1.e-8):
    """
    Compute the raw concordance index from 1-D numpy arrays.

    Internal helper; not meant to be called directly.

    Args:
        times: 1-D array of true times.
        predicted_times: 1-D array of predicted times, same length as `times`.
        events: 1-D array of event indicators, same length as `times`; it
            multiplies the pair mask, so a zero entry removes every pair
            anchored on that sample.
        tol: value added to the denominator so that an empty set of
            eligible pairs does not divide by zero.

    Returns:
        (well-ordered pairs + 0.5 * tied pairs) / (eligible pairs + tol).
    """
    assert times.ndim == predicted_times.ndim == events.ndim == 1, "wrong input, should be vectors only"
    assert times.shape[0] == predicted_times.shape[0] == events.shape[0], "wrong input, should be vectors of the same len"

    # A shorter predicted time means a higher risk.
    neg_pred = -predicted_times
    risk_col = neg_pred[:, np.newaxis]
    risk_row = neg_pred[np.newaxis, :]

    # Pair (i, j) is comparable when sample i has the strictly smaller true
    # time, weighted by sample i's event indicator.
    comparable = (times[:, np.newaxis] < times[np.newaxis, :]) * events[:, np.newaxis]

    concordant = np.sum(comparable * (risk_col > risk_row))
    # Pairs with identical predictions count for half.
    tied = np.sum(comparable * 0.5 * (risk_col == risk_row))

    return (concordant + tied) / (comparable.sum() + tol)