Diff of /survival4D/cox_reg.py [000000] .. [2cc208]

Switch to unified view

a b/survival4D/cox_reg.py
1
import optunity
2
import numpy as np, pandas as pd
3
from lifelines import CoxPHFitter
4
from lifelines.utils import concordance_index
5
6
7
def train_cox_reg(xtr, ytr, penalty):
8
    df_tr = pd.DataFrame(np.concatenate((ytr, xtr),axis=1))
9
    df_tr.columns = ['status', 'time'] + ['X'+str(i+1) for i in range(xtr.shape[1])]
10
    cph = CoxPHFitter(penalizer=penalty)
11
    cph.fit(df_tr, duration_col='time', event_col='status')
12
    return cph
13
14
15
# 2. 'Hyperparameter' search for Cox Regression model
16
def hypersearch_cox(x_data, y_data, method, nfolds, nevals, penalty_range):
17
    @optunity.cross_validated(x=x_data, y=y_data, num_folds=nfolds)
18
    def modelrun(x_train, y_train, x_test, y_test, penalty):
19
        cvmod = train_cox_reg(xtr=x_train, ytr=y_train, penalty=10 ** penalty)
20
        cv_preds = cvmod.predict_partial_hazard(x_test)
21
        cv_C = concordance_index(y_test[:, 1], -cv_preds, y_test[:, 0])
22
        return cv_C
23
    optimal_pars, searchlog, _ = optunity.maximize(modelrun, num_evals=nevals,
24
                                                   solver_name=method, penalty=penalty_range)
25
    print('Optimal hyperparameters : ' + str(optimal_pars))
26
    print('Cross-validated C after tuning: %1.3f' % searchlog.optimum)
27
    return optimal_pars, searchlog