|
a |
|
b/run_cox_baselines.py |
|
|
1 |
# Base / Native |
|
|
2 |
import os |
|
|
3 |
import pickle |
|
|
4 |
|
|
|
5 |
# Numerical / Array |
|
|
6 |
from lifelines.utils import concordance_index |
|
|
7 |
from lifelines import CoxPHFitter |
|
|
8 |
import numpy as np |
|
|
9 |
import pandas as pd |
|
|
10 |
pd.options.display.max_rows = 999 |
|
|
11 |
|
|
|
12 |
# Env |
|
|
13 |
from utils import CI_pm |
|
|
14 |
from utils import cox_log_rank |
|
|
15 |
from utils import getCleanAllDataset, addHistomolecularSubtype |
|
|
16 |
from utils import makeKaplanMeierPlot |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def trainCox(dataroot = './data/TCGA_GBMLGG/', ckpt_name='./checkpoints/surv_15_cox/', model='cox_omic', penalizer=1e-4): |
|
|
21 |
### Creates Checkpoint Directory |
|
|
22 |
if not os.path.exists(ckpt_name): os.makedirs(ckpt_name) |
|
|
23 |
if not os.path.exists(os.path.join(ckpt_name, model)): os.makedirs(os.path.join(ckpt_name, model)) |
|
|
24 |
|
|
|
25 |
### Load PNAS Splits |
|
|
26 |
pnas_splits = pd.read_csv(dataroot+'pnas_splits.csv') |
|
|
27 |
pnas_splits.columns = ['TCGA ID']+[str(k) for k in range(1, 16)] |
|
|
28 |
pnas_splits.index = pnas_splits['TCGA ID'] |
|
|
29 |
pnas_splits = pnas_splits.drop(['TCGA ID'], axis=1) |
|
|
30 |
|
|
|
31 |
### Loads Data |
|
|
32 |
ignore_missing_moltype = True if model in ['cox_omic', 'cox_moltype', 'cox_grade+moltype', 'all'] else False |
|
|
33 |
ignore_missing_histype = True if model in ['cox_histype', 'cox_grade', 'cox_grade+moltype', 'all'] else False |
|
|
34 |
all_dataset = getCleanAllDataset(dataroot=dataroot, ignore_missing_moltype=ignore_missing_moltype, |
|
|
35 |
ignore_missing_histype=ignore_missing_histype)[1] |
|
|
36 |
model_feats = {'cox_omic':['TCGA ID', 'Histology', 'Grade', 'Molecular subtype', 'Histomolecular subtype'], |
|
|
37 |
'cox_moltype':['Survival months', 'censored', 'codeletion', 'idh mutation'], |
|
|
38 |
'cox_histype':['Survival months', 'censored', 'Histology'], |
|
|
39 |
'cox_grade':['Survival months', 'censored', 'Grade'], |
|
|
40 |
'cox_grade+moltype':['Survival months', 'censored', 'codeletion', 'idh mutation', 'Grade'], |
|
|
41 |
'cox_all':['TCGA ID', 'Histomolecular subtype']} |
|
|
42 |
cv_results = [] |
|
|
43 |
|
|
|
44 |
for k in pnas_splits.columns: |
|
|
45 |
pat_train = list(set(pnas_splits.index[pnas_splits[k] == 'Train']).intersection(all_dataset.index)) |
|
|
46 |
pat_test = list(set(pnas_splits.index[pnas_splits[k] == 'Test']).intersection(all_dataset.index)) |
|
|
47 |
feats = all_dataset.columns.drop(model_feats[model]) if model == 'cox_omic' or model == 'cox_all' else model_feats[model] |
|
|
48 |
train = all_dataset.loc[pat_train] |
|
|
49 |
test = all_dataset.loc[pat_test] |
|
|
50 |
|
|
|
51 |
cph = CoxPHFitter(penalizer=penalizer) |
|
|
52 |
cph.fit(train[feats], duration_col='Survival months', event_col='censored', show_progress=False) |
|
|
53 |
cin = concordance_index(test['Survival months'], -cph.predict_partial_hazard(test[feats]), test['censored']) |
|
|
54 |
cv_results.append(cin) |
|
|
55 |
|
|
|
56 |
train.insert(loc=0, column='Hazard', value=-cph.predict_partial_hazard(train)) |
|
|
57 |
test.insert(loc=0, column='Hazard', value=-cph.predict_partial_hazard(test)) |
|
|
58 |
pickle.dump(train, open(os.path.join(ckpt_name, model, '%s_%s_pred_train.pkl' % (model, k)), 'wb')) |
|
|
59 |
pickle.dump(test, open(os.path.join(ckpt_name, model, '%s_%s_pred_test.pkl' % (model, k)), 'wb')) |
|
|
60 |
|
|
|
61 |
pickle.dump(cv_results, open(os.path.join(ckpt_name, model, '%s_results.pkl' % model), 'wb')) |
|
|
62 |
print("C-Indices across Splits", cv_results) |
|
|
63 |
print("Average C-Index: %f" % CI_pm(cv_results)) |
|
|
64 |
|
|
|
65 |
|
|
|
66 |
print('1. Omic Only. Ignore missing molecular subtypes') |
|
|
67 |
trainCox(model='cox_omic', penalizer=1e-1) |
|
|
68 |
print('2. molecular subtype only. Ignore missing molecular subtypes') |
|
|
69 |
trainCox(model='cox_moltype', penalizer=0) |
|
|
70 |
print('3. histology subtype only. Ignore missing histology subtypes') |
|
|
71 |
trainCox(model='cox_histype', penalizer=0) |
|
|
72 |
print('4. histologic grade only. Ignore missing histology subtypes') |
|
|
73 |
trainCox(model='cox_grade', penalizer=0) |
|
|
74 |
print('5. grade + molecular subtype. Ignore all NAs') |
|
|
75 |
trainCox(model='cox_grade+moltype', penalizer=0) |
|
|
76 |
print('6. All. Ignore all NAs') |
|
|
77 |
trainCox(model='cox_all', penalizer=1e-1) |
|
|
78 |
|
|
|
79 |
print('7. KM-Curves') |
|
|
80 |
for model in ['cox_omic', 'cox_moltype', 'cox_histype', 'cox_grade', 'cox_grade+moltype', 'cox_all']: |
|
|
81 |
makeKaplanMeierPlot(ckpt_name='./checkpoints/surv_15_cox/', model=model, split='test') |