|
a |
|
b/survival4D/cox_reg.py |
|
|
1 |
import optunity |
|
|
2 |
import numpy as np, pandas as pd |
|
|
3 |
from lifelines import CoxPHFitter |
|
|
4 |
from lifelines.utils import concordance_index |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def train_cox_reg(xtr, ytr, penalty): |
|
|
8 |
df_tr = pd.DataFrame(np.concatenate((ytr, xtr),axis=1)) |
|
|
9 |
df_tr.columns = ['status', 'time'] + ['X'+str(i+1) for i in range(xtr.shape[1])] |
|
|
10 |
cph = CoxPHFitter(penalizer=penalty) |
|
|
11 |
cph.fit(df_tr, duration_col='time', event_col='status') |
|
|
12 |
return cph |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
# 2. 'Hyperparameter' search for Cox Regression model |
|
|
16 |
def hypersearch_cox(x_data, y_data, method, nfolds, nevals, penalty_range): |
|
|
17 |
@optunity.cross_validated(x=x_data, y=y_data, num_folds=nfolds) |
|
|
18 |
def modelrun(x_train, y_train, x_test, y_test, penalty): |
|
|
19 |
cvmod = train_cox_reg(xtr=x_train, ytr=y_train, penalty=10 ** penalty) |
|
|
20 |
cv_preds = cvmod.predict_partial_hazard(x_test) |
|
|
21 |
cv_C = concordance_index(y_test[:, 1], -cv_preds, y_test[:, 0]) |
|
|
22 |
return cv_C |
|
|
23 |
optimal_pars, searchlog, _ = optunity.maximize(modelrun, num_evals=nevals, |
|
|
24 |
solver_name=method, penalty=penalty_range) |
|
|
25 |
print('Optimal hyperparameters : ' + str(optimal_pars)) |
|
|
26 |
print('Cross-validated C after tuning: %1.3f' % searchlog.optimum) |
|
|
27 |
return optimal_pars, searchlog |