experiments/other cancer/main_LUAD.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Zhi Huang
"""

import sys, os
sys.path.append("/home/zhihuan/Documents/SALMON/model")
import SALMON
import pandas as pd
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.autograd import Variable
from collections import Counter
import math
import random
from imblearn.over_sampling import RandomOverSampler
from lifelines.statistics import logrank_test
import json
import tables
import logging
import csv
import numpy as np
import optunity
import pickle
import time
from sklearn.model_selection import KFold
from sklearn import preprocessing
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so figures can be saved on headless machines
import matplotlib.pyplot as plt
plt.ioff()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', default='/home/zhihuan/Documents/SALMON/data/LUAD/multiomics_preprocessing_results/', help="datasets")
    parser.add_argument('--num_epochs', type=int, default=100, help="Number of epochs to train for. Default: 100")
    parser.add_argument('--measure_while_training', action='store_true', default=False, help='measure c-index on train/test at every epoch while training (slower)')
    parser.add_argument('--batch_size', type=int, default=64, help="Number of samples per batch for train/test. Default: 64")
    parser.add_argument('--dataset', type=int, default=7, help="Feature subset to use, 1-7 (see the if/elif chain below). Default: 7")
    parser.add_argument('--nocuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--verbose', default=1, type=int)
    parser.add_argument('--results_dir', default='/home/zhihuan/Documents/SALMON/experiments/Results/LUAD', help="results dir")
    return parser.parse_args()

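# example invocation (hypothetical values; the flags and their defaults are defined above):
#   python main_LUAD.py --dataset 7 --num_epochs 100 --batch_size 64
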
if __name__=='__main__':
    torch.cuda.empty_cache()
    args = parse_args()

    # hyperparameters
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    learning_rate_range = 10**np.arange(-4, -1, 0.3)  # log-spaced grid of candidate learning rates
    cuda = not args.nocuda  # honor the --nocuda flag
    verbose = args.verbose
    measure_while_training = True  # must stay True (this overrides --measure_while_training): the per-epoch c-index curves plotted below depend on it
    dropout_rate = 0
    lambda_1 = 1e-6  # L1 regularization weight
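    # for reference, the grid above works out to ten rates:
    # 1.0e-4, 2.0e-4, 4.0e-4, 7.9e-4, 1.6e-3, 3.2e-3, 6.3e-3, 1.3e-2, 2.5e-2, 5.0e-2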
    
    # load the preprocessed multi-omics tables; rows must stay aligned across
    # files, since the tables are concatenated by position below
    tempdata = {}
    tempdata['clinical'] = pd.read_csv(args.dataset_dir + 'clinical.csv', index_col = 0).reset_index(drop = True)
    tempdata['mRNAseq_eigengene'] = pd.read_csv(args.dataset_dir + 'mRNAseq_eigengene_matrix.csv', index_col = 0).reset_index(drop = True)
    tempdata['miRNAseq_eigengene'] = pd.read_csv(args.dataset_dir + 'miRNAseq_eigengene_matrix.csv', index_col = 0).reset_index(drop = True)
    tempdata['TMB'] = pd.read_csv(args.dataset_dir + 'TMB.csv', index_col = 0).reset_index(drop = True)
    tempdata['CNB'] = pd.read_csv(args.dataset_dir + 'CNB.csv', index_col = 0).reset_index(drop = True)
    tempdata['CNB']['log2_LENGTH_KB'] = np.log2(tempdata['CNB']['LENGTH_KB'].values + 1)
    
    # binarize clinical covariates (gender: MALE -> 1, FEMALE -> 0; vital status: Dead -> 1, Alive -> 0)
    print('0:FEMALE\t\t1:MALE\n0:Alive\t\t1:Dead')
    tempdata['clinical']['gender'] = (tempdata['clinical']['gender'].values == 'MALE').astype(int)
    tempdata['clinical']['vital_status'] = (tempdata['clinical']['vital_status'].values == 'Dead').astype(int)
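    # optional sanity check on the label distribution (assumes both classes occur)
    print('event counts (0=Alive, 1=Dead):', Counter(tempdata['clinical']['vital_status']))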
    
    # assemble the full design matrix: mRNAseq eigengenes, miRNAseq eigengenes,
    # copy number burden (CNB), tumor mutation burden (TMB), and two clinical covariates
    data = {}
    data['x'] = pd.concat((tempdata['mRNAseq_eigengene'], tempdata['miRNAseq_eigengene'], tempdata['CNB']['log2_LENGTH_KB'], tempdata['TMB']['All_TMB'], tempdata['clinical'][['gender','age_at_initial_pathologic_diagnosis']]), axis = 1).values.astype(np.double)
    all_column_names = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
                       ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
                       ['CNB', 'TMB', 'GENDER', 'AGE']
    print('perform min-max scaling on all input features')
    scaler = preprocessing.MinMaxScaler()
    scaler.fit(data['x'])
    data['x'] = scaler.transform(data['x'])
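    # after fitting and transforming the same matrix, every feature lies in [0, 1]
    # (up to floating-point tolerance)
    assert data['x'].min() >= -1e-9 and data['x'].max() <= 1.0 + 1e-9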
    
    # survival labels: e = event indicator (1 = death observed), t = survival time in days
    data['e'] = tempdata['clinical']['vital_status'].values.astype(np.int32)
    data['t'] = tempdata['clinical']['survival_days'].values.astype(np.double)
    
    # choose the feature subset for this experiment
    if args.dataset == 1:
        dataset_subset = "1_RNAseq"
        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])]

    elif args.dataset == 2:
        dataset_subset = "2_miRNAseq"
        data['column_names'] = ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])]

    elif args.dataset == 3:
        dataset_subset = "3_RNAseq+miRNAseq"
        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
                                ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])]
    elif args.dataset == 4:
        dataset_subset = "4_RNAseq+miRNAseq+cnb+tmb"
        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
                                ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
                                ['CNB', 'TMB']
    elif args.dataset == 5:
        dataset_subset = "5_RNAseq+miRNAseq+clinical"
        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
                                ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
                                ['GENDER', 'AGE']
    elif args.dataset == 6:
        dataset_subset = "6_cnb+tmb+clinical"
        data['column_names'] = ['CNB', 'TMB', 'GENDER', 'AGE']

    elif args.dataset == 7:
        dataset_subset = "7_RNAseq+miRNAseq+cnb+tmb+clinical"
        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
                                ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
                                ['CNB', 'TMB', 'GENDER', 'AGE']
    else:
        raise ValueError('--dataset must be an integer between 1 and 7')
    print('subsetting data...')
    data['x'] = data['x'][:, [i for i, c in enumerate(all_column_names) if c in data['column_names']]]
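    # the list comprehension above preserves column order; widths should agree
    assert data['x'].shape[1] == len(data['column_names'])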
    # 5-fold cross-validation split; the fixed random_state makes the folds reproducible
    kf = KFold(n_splits=5, shuffle=True, random_state=666)
    datasets_5folds = {}
    for ix, (train_index, test_index) in enumerate(kf.split(data['x']), start = 1):
        datasets_5folds[ix] = {}
        datasets_5folds[ix]['train'] = {}
        datasets_5folds[ix]['train']['x'] = data['x'][train_index, :]
        datasets_5folds[ix]['train']['e'] = data['e'][train_index]
        datasets_5folds[ix]['train']['t'] = data['t'][train_index]
        datasets_5folds[ix]['test'] = {}
        # index the held-out fold with test_index; reusing train_index here would
        # evaluate the model on its own training samples
        datasets_5folds[ix]['test']['x'] = data['x'][test_index, :]
        datasets_5folds[ix]['test']['e'] = data['e'][test_index]
        datasets_5folds[ix]['test']['t'] = data['t'][test_index]
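        # folds are disjoint by construction: no patient index appears in both splits
        assert len(np.intersect1d(train_index, test_index)) == 0
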
    for i in range(1, len(datasets_5folds) + 1):
        print("5 fold CV -- %d/5" % i)

        # one results directory per fold, stamped with the launch time
        TIMESTRING = time.strftime("%Y%m%d-%H.%M.%S", time.localtime())

        results_dir_dataset = args.results_dir + '/' + dataset_subset + '/run_' + TIMESTRING + '_fold_' + str(i)
        if not os.path.exists(results_dir_dataset):
            os.makedirs(results_dir_dataset)

        # note: logging.basicConfig is a no-op after the first call, so every fold
        # keeps logging to the first fold's mainlog.log (Python 3.8+ could pass force=True)
        logging.basicConfig(filename=results_dir_dataset+'/mainlog.log', level=logging.DEBUG)
        # print("Arguments:", args)
        # logging.info("Arguments: %s" % args)
        datasets = datasets_5folds[i]

        # feature count per data modality (passed to SALMON.train/test below);
        # 'clinical' covers the two covariates kept above (gender, age)
        length_of_data = {}
        length_of_data['mRNAseq'] = tempdata['mRNAseq_eigengene'].shape[1]
        length_of_data['miRNAseq'] = tempdata['miRNAseq_eigengene'].shape[1]
        length_of_data['CNB'] = 1
        length_of_data['TMB'] = 1
        length_of_data['clinical'] = 2
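        # note: these block sizes always describe the full seven-block layout; for
        # smaller --dataset subsets, SALMON is expected to consume only the blocks
        # actually present in datasets['train']['x']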
        
    # =============================================================================
    # # Finding optimal learning rate w.r.t. concordance index
    # =============================================================================
        ci_list = []
        for j, lr in enumerate(learning_rate_range):
            print("[%d/%d] current lr: %.4E" % ((j+1), len(learning_rate_range), lr))
            logging.info("[%d/%d] current lr: %.4E" % ((j+1), len(learning_rate_range), lr))
            model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = \
                 SALMON.train(datasets, num_epochs, batch_size, lr, dropout_rate,\
                                         lambda_1, length_of_data, cuda, measure_while_training, verbose)

            # convergence curves for this candidate learning rate (needs measure_while_training=True)
            epochs_list = range(num_epochs)
            plt.figure(figsize=(8,4))
            plt.plot(epochs_list, c_index_list['train'], "b--", linewidth=1)
            plt.plot(epochs_list, c_index_list['test'], "g-", linewidth=1)
            plt.legend(['train', 'test'])
            plt.xlabel("epochs")
            plt.ylabel("Concordance index")
            plt.savefig(results_dir_dataset + "/convergence_%02d_lr=%.2E.png" % (j, lr), dpi=300)
            plt.close()
            code_test, loss_nn_sum, acc_test, pvalue_pred, c_index_pred, lbl_pred_all, OS_event_test, OS_test = \
                SALMON.test(model, datasets, 'test', length_of_data, batch_size, cuda, verbose)
            ci_list.append(c_index_pred)
            print("current concordance index: ", c_index_pred, "\n")
            logging.info("current concordance index: %.10f\n" % c_index_pred)

        # np.argmax returns the first maximum, so ties favor the smaller learning rate
        optimal_lr = learning_rate_range[np.argmax(ci_list)]

        print("Optimal learning rate: %.4E, optimal c-index: %.10f" % (optimal_lr, np.max(ci_list)))
        logging.info("Optimal learning rate: %.4E, optimal c-index: %.10f" % (optimal_lr, np.max(ci_list)))
    
    # =============================================================================
    # # Training
    # =============================================================================

        # retrain on this fold with the selected learning rate
        model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = \
                 SALMON.train(datasets, num_epochs, batch_size, optimal_lr, dropout_rate,\
                                         lambda_1, length_of_data, cuda, measure_while_training, verbose)
        code_train, loss_nn_sum, acc_train, pvalue_pred, c_index_pred, lbl_pred_all_train, OS_event_train, OS_train = \
            SALMON.test(model, datasets, 'train', length_of_data, batch_size, cuda, verbose)
        print("[Final] Apply model to training set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))
        logging.info("[Final] Apply model to training set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))

        code_test, loss_nn_sum, acc_test, pvalue_pred, c_index_pred, lbl_pred_all_test, OS_event_test, OS_test = \
            SALMON.test(model, datasets, 'test', length_of_data, batch_size, cuda, verbose)
        print("[Final] Apply model to testing set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))
        logging.info("[Final] Apply model to testing set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))
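
        # Extra (hedged) check: a log-rank test between median-split risk groups on
        # the test fold, using the lifelines helper imported above. This assumes the
        # SALMON outputs lbl_pred_all_test (predicted hazards), OS_test (survival
        # times), and OS_event_test (event indicators) convert to 1-D NumPy arrays.
        hazards = np.asarray(lbl_pred_all_test).ravel()
        times = np.asarray(OS_test).ravel()
        events = np.asarray(OS_event_test).ravel()
        high_risk = hazards > np.median(hazards)
        lr_result = logrank_test(times[high_risk], times[~high_risk],
                                 event_observed_A=events[high_risk],
                                 event_observed_B=events[~high_risk])
        logging.info("median-split log-rank p-value (test): %.10e" % lr_result.p_value)
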
        # persist the model and all evaluation outputs for this fold
        with open(results_dir_dataset + '/model.pickle', 'wb') as handle:
            pickle.dump(model, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/c_index_list_by_epochs.pickle', 'wb') as handle:
            pickle.dump(c_index_list, handle, protocol=pickle.HIGHEST_PROTOCOL)

        with open(results_dir_dataset + '/hazard_ratios_lbl_pred_all_train.pickle', 'wb') as handle:
            pickle.dump(lbl_pred_all_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_event_train.pickle', 'wb') as handle:
            pickle.dump(OS_event_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_train.pickle', 'wb') as handle:
            pickle.dump(OS_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/code_train.pickle', 'wb') as handle:
            pickle.dump(code_train, handle, protocol=pickle.HIGHEST_PROTOCOL)

        with open(results_dir_dataset + '/hazard_ratios_lbl_pred_all_test.pickle', 'wb') as handle:
            pickle.dump(lbl_pred_all_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_event_test.pickle', 'wb') as handle:
            pickle.dump(OS_event_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_test.pickle', 'wb') as handle:
            pickle.dump(OS_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/code_test.pickle', 'wb') as handle:
            pickle.dump(code_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
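        # the pickles can be reloaded later for downstream analysis, e.g.:
        #   with open(results_dir_dataset + '/model.pickle', 'rb') as handle:
        #       model = pickle.load(handle)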
    
        # final convergence curves for the retrained model
        epochs_list = range(num_epochs)
        plt.figure(figsize=(8,4))
        plt.plot(epochs_list, c_index_list['train'], "b--", linewidth=1)
        plt.plot(epochs_list, c_index_list['test'], "g-", linewidth=1)
        plt.legend(['train', 'test'])
        plt.xlabel("epochs")
        plt.ylabel("Concordance index")
        plt.savefig(results_dir_dataset + "/convergence.png", dpi=300)
        plt.close()