|
a |
|
b/coxnet/coxnet_baseline.py |
|
|
1 |
""" |
|
|
2 |
Coxnet: CoxPH with an elastic-net (L1/L2) penalty, fitted on the Owkin dataset. |
|
|
3 |
Leon Zheng |
|
|
4 |
""" |
|
|
5 |
|
|
|
6 |
import preprocessing |
|
|
7 |
from sksurv.linear_model import CoxnetSurvivalAnalysis |
|
|
8 |
from sklearn.model_selection import cross_validate, RandomizedSearchCV |
|
|
9 |
from sksurv.util import Surv |
|
|
10 |
import numpy as np |
|
|
11 |
import pandas as pd |
|
|
12 |
|
|
|
13 |
# Features |
|
|
14 |
# features = ['original_shape_Compactness2', |
|
|
15 |
# 'original_shape_SphericalDisproportion', |
|
|
16 |
# 'original_shape_SurfaceVolumeRatio', |
|
|
17 |
# 'original_firstorder_Kurtosis', |
|
|
18 |
# 'original_firstorder_MeanAbsoluteDeviation', |
|
|
19 |
# 'original_firstorder_Minimum', |
|
|
20 |
# 'original_glcm_ClusterProminence', |
|
|
21 |
# 'original_glcm_Contrast', |
|
|
22 |
# 'original_glcm_DifferenceEntropy', |
|
|
23 |
# 'original_glcm_DifferenceAverage', |
|
|
24 |
# 'original_glcm_JointEnergy', |
|
|
25 |
# 'original_glcm_Id', |
|
|
26 |
# 'original_glcm_Idm', |
|
|
27 |
# 'original_glcm_Imc1', |
|
|
28 |
# 'original_glcm_Imc2', |
|
|
29 |
# 'original_glcm_Idmn', |
|
|
30 |
# 'original_glcm_Idn', |
|
|
31 |
# 'original_glrlm_ShortRunEmphasis', |
|
|
32 |
# 'original_glrlm_LongRunEmphasis', |
|
|
33 |
# 'original_glrlm_GrayLevelNonUniformity', |
|
|
34 |
# 'original_glrlm_RunPercentage', |
|
|
35 |
# 'original_glrlm_ShortRunLowGrayLevelEmphasis', |
|
|
36 |
# 'original_glrlm_LongRunLowGrayLevelEmphasis', |
|
|
37 |
# 'original_glrlm_LongRunHighGrayLevelEmphasis', |
|
|
38 |
# 'Nstage', |
|
|
39 |
# 'age', |
|
|
40 |
# 'SourceDataset'] |
|
|
41 |
|
|
|
42 |
# radiomics_features = ['original_shape_Sphericity', 'original_shape_SurfaceVolumeRatio', |
|
|
43 |
# 'original_shape_Maximum3DDiameter', 'original_glcm_JointEntropy', 'original_glcm_Id', |
|
|
44 |
# 'original_glcm_Idm'] |
|
|
45 |
# clinical_features = ['SourceDataset', 'Nstage'] |
|
|
46 |
# features = radiomics_features + clinical_features |
|
|
47 |
|
|
|
48 |
# Columns selected from the input tables before model fitting.
features = [
    'Mstage',
    'Nstage',
    'SourceDataset',
    'age',
    'original_shape_VoxelVolume',
    'original_firstorder_Maximum',
    'original_firstorder_Mean',
    'original_glcm_ClusterProminence',
    'original_glcm_Idm',
    'original_glcm_Idn',
    'original_glrlm_RunPercentage',
]

# Load the data, keep only the selected columns in both input tables and pass
# them through the project's normalisation helper.
input_train, output_train, input_test = preprocessing.load_owkin_data()
input_train, input_test = preprocessing.normalizing_input(
    input_train[features],
    input_test[features],
)

# Structured (event, time) target built from the training outcome table.
structured_y = Surv.from_dataframe(
    event='Event',
    time='SurvivalTime',
    data=output_train,
)
|
|
66 |
|
|
|
67 |
# Coxnet |
|
|
68 |
# coxnet = CoxnetSurvivalAnalysis() |
|
|
69 |
# print(cross_validate(coxnet, input_train, structured_y, cv=5)) |
|
|
70 |
|
|
|
71 |
# Hyper-parameter search.
# NOTE(review): despite the `grid_search` name this is a randomized search, and
# no `random_state` is passed, so the selected parameters may differ from run
# to run.
tuned_params = dict(
    l1_ratio=np.linspace(0.01, 0.02, num=100),
    n_alphas=range(140, 160),
)
grid_search = RandomizedSearchCV(
    estimator=CoxnetSurvivalAnalysis(),
    param_distributions=tuned_params,
    n_iter=1000,
    cv=5,
    n_jobs=4,
)
grid_search.fit(input_train, structured_y)

# Report the best cross-validated score and keep the winning parameters for
# the final model fit.
print(grid_search.best_score_)
best_params = grid_search.best_params_
print(best_params)
|
|
80 |
|
|
|
81 |
# Prediction |
|
|
82 |
def predict(model, X, threshold=0.9):
    """Turn predicted survival curves into one survival-time estimate per row.

    For each curve returned by ``model.predict_survival_function(X)`` the
    estimate is the time at the first step where the survival probability is
    no longer strictly greater than ``threshold``. If the curve never gets
    there, the last time point of the curve is used.

    Parameters
    ----------
    model : object
        Must expose ``predict_survival_function(X)`` returning an iterable of
        objects with array-like attributes ``x`` (times) and ``y`` (survival
        probabilities).
    X : object with an ``index`` attribute
        Samples to predict for; ``X.index`` becomes the index of the result.
    threshold : float, default 0.9
        Survival-probability level at which the time is read off the curve.

    Returns
    -------
    pandas.DataFrame
        Indexed like ``X`` with a float ``SurvivalTime`` column and an
        ``Event`` column filled with NaN.
    """
    survival_times = []
    for step_fn in model.predict_survival_function(X):
        probs = np.asarray(step_fn.y)
        # `~(probs > threshold)` instead of `probs <= threshold` so that a NaN
        # probability ends the scan, exactly as a plain `>` comparison would.
        reached = np.flatnonzero(~(probs > threshold))
        # Fall back to the last step when the curve never reaches the threshold.
        position = reached[0] if reached.size else len(probs) - 1
        survival_times.append(step_fn.x[position])

    # Building the frame column by column keeps the two-column layout even
    # when there are no rows (a nested-list array collapses to shape (0,)).
    return pd.DataFrame(
        {
            'SurvivalTime': np.asarray(survival_times, dtype=float),
            'Event': np.nan,
        },
        index=X.index,
        columns=['SurvivalTime', 'Event'],
    )
|
|
93 |
|
|
|
94 |
# Refit on the full training set with the selected hyper-parameters. The
# baseline model is requested because `predict` calls
# `predict_survival_function` on the fitted estimator.
coxph = CoxnetSurvivalAnalysis(fit_baseline_model=True, **best_params)
coxph.fit(input_train, structured_y)

# Predict on the test inputs and write the result to the submission file.
y_pred = predict(model=coxph, X=input_test)
y_pred.to_csv('submission.csv')