Switch to unified view

a b/baseline/Coxph_regression.py
1
"""
2
COX-PH Baseline
3
Leon Zheng
4
"""
5
6
import pandas as pd
7
from sksurv.linear_model import CoxPHSurvivalAnalysis
8
from sklearn.base import BaseEstimator
9
from sksurv.util import Surv
10
import numpy as np
11
from metrics import cindex
12
13
14
class CoxPhRegression(BaseEstimator):
15
16
    def __init__(self, alpha=1e-8, threshold=0.9):
17
        self.alpha = alpha
18
        self.threshold = threshold
19
        self.model = CoxPHSurvivalAnalysis(alpha=alpha)
20
21
    def set_data(self, input_train, output_train, input_test):
22
        self.input_train = input_train
23
        self.output_train = output_train
24
        self.input_test = input_test
25
26
    def fit(self, X, y):
27
        structured_y = Surv.from_dataframe('Event', 'SurvivalTime', y)
28
        self.model.fit(X, structured_y)
29
        return self
30
31
    def predict(self, X):
32
        prediction = self.model.predict_survival_function(X)
33
        y_pred = []
34
        for pred in prediction:
35
            time = pred.x
36
            survival_prob = pred.y
37
            i_pred = 0
38
            while i_pred < len(survival_prob) - 1 and survival_prob[i_pred] > self.threshold:
39
                i_pred += 1
40
            y_pred.append(time[i_pred])
41
        return pd.DataFrame(np.array([[y, np.nan] for y in y_pred]), index=X.index, columns=['SurvivalTime', 'Event'])
42
43
    def score(self, X, y):
44
        y_pred = self.predict(X)
45
        return cindex(y, y_pred)