Diff of /metrics.py [000000] .. [785f18]

Switch to unified view

a b/metrics.py
1
import numpy as np
2
import pandas as pd
3
import warnings
4
5
6
def cindex(y_true_times, predicted_times, tol=1e-8):
7
    """
8
    Author: Romuald Menuet & Rémy Dubois
9
10
    Evaluate concordance index from Pandas DataFrame, taking ties into account.
11
12
    Args:
13
        y_true_times: pd.DataFrame
14
            pd DataFrame with three columns: `PatientID`, `Event` and `SurvivalTime` the float-valued column of true survival times.
15
        predicted_times: pd.DataFrame
16
            pd DataFrame with three columns: `PatientID`, `SurvivalTime` the float-valued column of predicted survival times,
17
            and one `Event`column, whose value does not matter. It must be appended so that target and predictions have the same format.
18
        tol: float
19
            small float value for numerical stability.
20
    Returns:
21
        Concordance index, as described here:
22
        https://square.github.io/pysurvival/metrics/c_index.html
23
    """
24
25
    assert isinstance(y_true_times, pd.DataFrame), 'Y true times should be pd dataframe with `PatientID` as index, and `Event` and `SurvivalTime` as columns'
26
    assert isinstance(predicted_times, pd.DataFrame), 'Predicted times should be pd dataframe with patient `PatientID` as index, and `Event` and `SurvivalTime` as columns'
27
    assert len(y_true_times.shape) == 2, 'Y true times should be pd dataframe with `PatientID` as index, and `Event` and `SurvivalTime` as columns'
28
    assert len(predicted_times.shape) == 2, 'Predicted times should be pd dataframe with `PatientID` as index, and `Event` and `SurvivalTime` as columns'
29
    assert set(y_true_times.columns) == {'Event', 'SurvivalTime'}, 'Y true times should be pd dataframe with `PatientID` as index, and `Event` and `SurvivalTime` as columns'
30
    assert set(predicted_times.columns) == {'Event', 'SurvivalTime'}, 'Predicted times should be pd dataframe with `PatientID` as index, and `Event` and `SurvivalTime` as columns'
31
    np.testing.assert_equal(y_true_times.shape, predicted_times.shape, err_msg="Not same amount of predicted versus true samples")
32
    assert set(y_true_times.index) == set(predicted_times.index), 'Not same patients in prediction versus ground truth'
33
    assert np.all(predicted_times['SurvivalTime'] > 0), 'Predicted times should all be positive'
34
35
    events = y_true_times.Event
36
    y_true_times = y_true_times.SurvivalTime
37
    predicted_times = predicted_times.SurvivalTime
38
39
    # Just ordering the right way
40
    predicted_times = predicted_times.loc[y_true_times.index]
41
    events = events.loc[y_true_times.index]
42
43
    events = events.values.astype(int)
44
    y_true_times = y_true_times.values.astype(float)
45
    predicted_times = predicted_times.values.astype(float)
46
    # events = events.values.astype(bool)
47
48
    np.testing.assert_array_less(1.,
49
                                 predicted_times.astype(float),
50
                                 err_msg="Predicted y_true_times all below 1.\
51
                                 It should be in days. Make sure that you are not predicting risk instead of time.")
52
53
    return _cindex_np(y_true_times, predicted_times, events)
54
55
56
def _cindex_np(times, predicted_times, events, tol=1.e-8):
57
    """
58
    Raw CI computation from np arrray. Should not be used as is.
59
    """
60
    assert times.ndim == predicted_times.ndim == events.ndim == 1, "wrong input, should be vectors only"
61
    assert times.shape[0] == predicted_times.shape[0] == events.shape[0], "wrong input, should be vectors of the same len"
62
63
    risks = - predicted_times
64
65
    risks_i = risks.reshape((-1, 1))
66
    risks_j = risks.reshape((1, -1))
67
    times_i = times.reshape((-1, 1))
68
    times_j = times.reshape((1, -1))
69
    events_i = events.reshape((-1, 1))
70
71
    eligible_pairs = (times_i < times_j) * events_i
72
73
    well_ordered = np.sum(eligible_pairs * (risks_i > risks_j))
74
    ties = + np.sum(eligible_pairs * 0.5 * (risks_i == risks_j))
75
76
    return (well_ordered + ties) / (eligible_pairs.sum() + tol)