Diff of /run_cox_baselines.py [000000] .. [2095ed]

Switch to unified view

a b/run_cox_baselines.py
1
# Base / Native
2
import os
3
import pickle
4
5
# Numerical / Array
6
from lifelines.utils import concordance_index
7
from lifelines import CoxPHFitter
8
import numpy as np
9
import pandas as pd
10
pd.options.display.max_rows = 999
11
12
# Env
13
from utils import CI_pm
14
from utils import cox_log_rank
15
from utils import getCleanAllDataset, addHistomolecularSubtype
16
from utils import makeKaplanMeierPlot
17
18
19
20
def trainCox(dataroot = './data/TCGA_GBMLGG/', ckpt_name='./checkpoints/surv_15_cox/', model='cox_omic', penalizer=1e-4):
21
    ### Creates Checkpoint Directory
22
    if not os.path.exists(ckpt_name): os.makedirs(ckpt_name)
23
    if not os.path.exists(os.path.join(ckpt_name, model)): os.makedirs(os.path.join(ckpt_name, model))
24
    
25
    ### Load PNAS Splits
26
    pnas_splits = pd.read_csv(dataroot+'pnas_splits.csv')
27
    pnas_splits.columns = ['TCGA ID']+[str(k) for k in range(1, 16)]
28
    pnas_splits.index = pnas_splits['TCGA ID']
29
    pnas_splits = pnas_splits.drop(['TCGA ID'], axis=1)
30
    
31
    ### Loads Data
32
    ignore_missing_moltype = True if model in ['cox_omic', 'cox_moltype', 'cox_grade+moltype', 'all'] else False
33
    ignore_missing_histype = True if model in ['cox_histype', 'cox_grade', 'cox_grade+moltype', 'all'] else False
34
    all_dataset = getCleanAllDataset(dataroot=dataroot, ignore_missing_moltype=ignore_missing_moltype, 
35
                                     ignore_missing_histype=ignore_missing_histype)[1]
36
    model_feats = {'cox_omic':['TCGA ID', 'Histology', 'Grade', 'Molecular subtype', 'Histomolecular subtype'],
37
                   'cox_moltype':['Survival months', 'censored', 'codeletion', 'idh mutation'],
38
                   'cox_histype':['Survival months', 'censored', 'Histology'],
39
                   'cox_grade':['Survival months', 'censored', 'Grade'],
40
                   'cox_grade+moltype':['Survival months', 'censored', 'codeletion', 'idh mutation', 'Grade'],
41
                   'cox_all':['TCGA ID', 'Histomolecular subtype']}
42
    cv_results = []
43
44
    for k in pnas_splits.columns:
45
        pat_train = list(set(pnas_splits.index[pnas_splits[k] == 'Train']).intersection(all_dataset.index))
46
        pat_test = list(set(pnas_splits.index[pnas_splits[k] == 'Test']).intersection(all_dataset.index))
47
        feats = all_dataset.columns.drop(model_feats[model]) if model == 'cox_omic' or model == 'cox_all' else model_feats[model]
48
        train = all_dataset.loc[pat_train]
49
        test = all_dataset.loc[pat_test]
50
51
        cph = CoxPHFitter(penalizer=penalizer)
52
        cph.fit(train[feats], duration_col='Survival months', event_col='censored', show_progress=False)
53
        cin = concordance_index(test['Survival months'], -cph.predict_partial_hazard(test[feats]), test['censored'])
54
        cv_results.append(cin)
55
        
56
        train.insert(loc=0, column='Hazard', value=-cph.predict_partial_hazard(train))
57
        test.insert(loc=0, column='Hazard', value=-cph.predict_partial_hazard(test))
58
        pickle.dump(train, open(os.path.join(ckpt_name, model, '%s_%s_pred_train.pkl' % (model, k)), 'wb'))
59
        pickle.dump(test, open(os.path.join(ckpt_name, model, '%s_%s_pred_test.pkl' % (model, k)), 'wb'))
60
        
61
    pickle.dump(cv_results, open(os.path.join(ckpt_name, model, '%s_results.pkl' % model), 'wb'))
62
    print("C-Indices across Splits", cv_results)
63
    print("Average C-Index: %f" % CI_pm(cv_results))
64
65
66
print('1. Omic Only. Ignore missing molecular subtypes')
67
trainCox(model='cox_omic', penalizer=1e-1)
68
print('2. molecular subtype only. Ignore missing molecular subtypes')
69
trainCox(model='cox_moltype', penalizer=0)
70
print('3. histology subtype only. Ignore missing histology subtypes')
71
trainCox(model='cox_histype', penalizer=0)
72
print('4. histologic grade only. Ignore missing histology subtypes')
73
trainCox(model='cox_grade', penalizer=0)
74
print('5. grade + molecular subtype. Ignore all NAs')
75
trainCox(model='cox_grade+moltype', penalizer=0)
76
print('6. All. Ignore all NAs')
77
trainCox(model='cox_all', penalizer=1e-1)
78
79
print('7. KM-Curves')
80
for model in ['cox_omic', 'cox_moltype', 'cox_histype', 'cox_grade', 'cox_grade+moltype', 'cox_all']:
81
    makeKaplanMeierPlot(ckpt_name='./checkpoints/surv_15_cox/', model=model, split='test')