|
a |
|
b/baseline/Coxph_regression.py |
|
|
1 |
""" |
|
|
2 |
COX-PH Baseline |
|
|
3 |
Leon Zheng |
|
|
4 |
""" |
|
|
5 |
|
|
|
6 |
import pandas as pd |
|
|
7 |
from sksurv.linear_model import CoxPHSurvivalAnalysis |
|
|
8 |
from sklearn.base import BaseEstimator |
|
|
9 |
from sksurv.util import Surv |
|
|
10 |
import numpy as np |
|
|
11 |
from metrics import cindex |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
class CoxPhRegression(BaseEstimator): |
|
|
15 |
|
|
|
16 |
def __init__(self, alpha=1e-8, threshold=0.9): |
|
|
17 |
self.alpha = alpha |
|
|
18 |
self.threshold = threshold |
|
|
19 |
self.model = CoxPHSurvivalAnalysis(alpha=alpha) |
|
|
20 |
|
|
|
21 |
def set_data(self, input_train, output_train, input_test): |
|
|
22 |
self.input_train = input_train |
|
|
23 |
self.output_train = output_train |
|
|
24 |
self.input_test = input_test |
|
|
25 |
|
|
|
26 |
def fit(self, X, y): |
|
|
27 |
structured_y = Surv.from_dataframe('Event', 'SurvivalTime', y) |
|
|
28 |
self.model.fit(X, structured_y) |
|
|
29 |
return self |
|
|
30 |
|
|
|
31 |
def predict(self, X): |
|
|
32 |
prediction = self.model.predict_survival_function(X) |
|
|
33 |
y_pred = [] |
|
|
34 |
for pred in prediction: |
|
|
35 |
time = pred.x |
|
|
36 |
survival_prob = pred.y |
|
|
37 |
i_pred = 0 |
|
|
38 |
while i_pred < len(survival_prob) - 1 and survival_prob[i_pred] > self.threshold: |
|
|
39 |
i_pred += 1 |
|
|
40 |
y_pred.append(time[i_pred]) |
|
|
41 |
return pd.DataFrame(np.array([[y, np.nan] for y in y_pred]), index=X.index, columns=['SurvivalTime', 'Event']) |
|
|
42 |
|
|
|
43 |
def score(self, X, y): |
|
|
44 |
y_pred = self.predict(X) |
|
|
45 |
return cindex(y, y_pred) |