Diff of /experiments/main.py [000000] .. [a23a6e]

b/experiments/main.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Zhi Huang
"""

import sys, os
from pathlib import Path
project_folder = Path("..").resolve()
model_folder = project_folder / "model"
sys.path.append(model_folder.absolute().as_posix())
import SALMON
import pandas as pd
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.autograd import Variable
from collections import Counter
import matplotlib.pyplot as plt
import math
import random
from imblearn.over_sampling import RandomOverSampler
from lifelines.statistics import logrank_test
import json
import tables
import logging
import csv
import numpy as np
import optunity
import pickle
import time
from sklearn.model_selection import KFold
from sklearn import preprocessing


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=100, help="Number of epochs to train for. Default: 100")
    parser.add_argument('--measure_while_training', action='store_true', default=False, help="compute train/test metrics at every epoch (slower, but enables the convergence plots)")
    parser.add_argument('--batch_size', type=int, default=256, help="Batch size for training/testing. Default: 256")
    parser.add_argument('--dataset', type=int, default=7, choices=range(1, 8), help="Feature subset to use (1-7). Default: 7")
    parser.add_argument('--nocuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--verbose', default=1, type=int)
    parser.add_argument('--results_dir', default=(project_folder / "experiments/Results").absolute().as_posix(), help="results dir")

    return parser.parse_args()
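
# Example invocation (run from the experiments/ folder, flags as defined above):
#   python main.py --dataset 7 --num_epochs 100 --batch_size 256
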
if __name__ == '__main__':
    torch.cuda.empty_cache()
    args = parse_args()
    plt.ioff()

    # hyperparameters
    num_epochs = args.num_epochs
    batch_size = args.batch_size
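    # 10**np.arange(-4, -1, 0.3) gives ten log-spaced candidates:
    # exponents -4.0, -3.7, ..., -1.3, i.e. rates from 1e-4 to about 5.0e-2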
    learning_rate_range = 10**np.arange(-4,-1,0.3)
    cuda = not args.nocuda
    verbose = args.verbose
    # left hard-coded on: the convergence plots below depend on the
    # per-epoch c-index curves recorded during training
    measure_while_training = True
    dropout_rate = 0
    lambda_1 = 1e-5  # L1 regularization weight

    if args.dataset == 1:
        dataset_subset = "1_RNAseq"
    elif args.dataset == 2:
        dataset_subset = "2_miRNAseq"
    elif args.dataset == 3:
        dataset_subset = "3_RNAseq+miRNAseq"
    elif args.dataset == 4:
        dataset_subset = "4_RNAseq+miRNAseq+cnv+tmb"
    elif args.dataset == 5:
        dataset_subset = "5_RNAseq+miRNAseq+clinical"
    elif args.dataset == 6:
        dataset_subset = "6_cnv+tmb+clinical"
    elif args.dataset == 7:
        dataset_subset = "7_RNAseq+miRNAseq+cnv+tmb+clinical"

    with open((project_folder / "data/BRCA_583_new/datasets_5folds.pickle").absolute().as_posix(), "rb") as f:
        datasets_5folds = pickle.load(f)

    for i in range(5):
        print("5 fold CV -- %d/5" % (i+1))

        # per-fold results directory, stamped with the current time
        TIMESTRING = time.strftime("%Y%m%d-%H.%M.%S", time.localtime())

        results_dir_dataset = args.results_dir + '/' + dataset_subset + '/run_' + TIMESTRING + '_fold_' + str(i+1)
        os.makedirs(results_dir_dataset, exist_ok=True)

        # force=True (Python 3.8+) re-points the root logger at this fold's log
        # file; without it, basicConfig is a no-op after the first fold
        logging.basicConfig(filename=results_dir_dataset + '/mainlog.log', level=logging.DEBUG, force=True)
        # print("Arguments:", args)
        # logging.info("Arguments: %s" % args)
        datasets = datasets_5folds[str(i+1)]

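        # column layout of datasets[...]['x'] (74 columns total):
        #   [0:57)  mRNAseq    [57:69) miRNAseq
        #   [69]    CNB        [70]    TMB      [71:74) clinical (age, ER, PR)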
        len_of_RNAseq = 57
        len_of_miRNAseq = 12
        len_of_cnv = 1
        len_of_tmb = 1
        len_of_clinical = 3

        length_of_data = {
            'mRNAseq': len_of_RNAseq,
            'miRNAseq': len_of_miRNAseq,
            'CNB': len_of_cnv,
            'TMB': len_of_tmb,
            'clinical': len_of_clinical,
        }

        if args.dataset == 1:
            ####      RNAseq only
            datasets['train']['x'] = datasets['train']['x'][:, 0:len_of_RNAseq]
            datasets['test']['x'] = datasets['test']['x'][:, 0:len_of_RNAseq]
        elif args.dataset == 2:
            ####      miRNAseq only
            datasets['train']['x'] = datasets['train']['x'][:, len_of_RNAseq:(len_of_RNAseq + len_of_miRNAseq)]
            datasets['test']['x'] = datasets['test']['x'][:, len_of_RNAseq:(len_of_RNAseq + len_of_miRNAseq)]
        elif args.dataset == 3:
            ####      RNAseq + miRNAseq
            datasets['train']['x'] = datasets['train']['x'][:, 0:(len_of_RNAseq + len_of_miRNAseq)]
            datasets['test']['x'] = datasets['test']['x'][:, 0:(len_of_RNAseq + len_of_miRNAseq)]
        elif args.dataset == 4:
            ####      RNAseq + miRNAseq + CNB + TMB
            datasets['train']['x'] = datasets['train']['x'][:, 0:(len_of_RNAseq + len_of_miRNAseq + len_of_cnv + len_of_tmb)]
            datasets['test']['x'] = datasets['test']['x'][:, 0:(len_of_RNAseq + len_of_miRNAseq + len_of_cnv + len_of_tmb)]
        elif args.dataset == 5:
            ####      RNAseq + miRNAseq + clinical (age+ER+PR): splice out the CNB and TMB columns
            datasets['train']['x'] = np.concatenate((datasets['train']['x'][:, 0:(len_of_RNAseq + len_of_miRNAseq)],
                                                     datasets['train']['x'][:, (len_of_RNAseq + len_of_miRNAseq + len_of_cnv + len_of_tmb):]), 1)
            datasets['test']['x'] = np.concatenate((datasets['test']['x'][:, 0:(len_of_RNAseq + len_of_miRNAseq)],
                                                    datasets['test']['x'][:, (len_of_RNAseq + len_of_miRNAseq + len_of_cnv + len_of_tmb):]), 1)
        elif args.dataset == 6:
            ####      CNB + TMB + clinical (age+ER+PR)
            datasets['train']['x'] = datasets['train']['x'][:, (len_of_RNAseq + len_of_miRNAseq):]
            datasets['test']['x'] = datasets['test']['x'][:, (len_of_RNAseq + len_of_miRNAseq):]
        elif args.dataset == 7:
            ####      RNAseq + miRNAseq + CNB + TMB + clinical (age+ER+PR): all columns, unchanged
            pass

    # =============================================================================
    # # Finding optimal learning rate w.r.t. concordance index
    # =============================================================================
        ci_list = []
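        # one full training run per candidate lr; the lr with the highest
        # test-fold c-index is selected for the final model below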
        for j, lr in enumerate(learning_rate_range):
            print("[%d/%d] current lr: %.4E" % ((j+1), len(learning_rate_range), lr))
            logging.info("[%d/%d] current lr: %.4E" % ((j+1), len(learning_rate_range), lr))
            model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = \
                SALMON.train(datasets, num_epochs, batch_size, lr, dropout_rate,
                             lambda_1, length_of_data, cuda, measure_while_training, verbose)

            epochs_list = range(num_epochs)
            plt.figure(figsize=(8, 4))
            plt.plot(epochs_list, c_index_list['train'], "b--", linewidth=1)
            plt.plot(epochs_list, c_index_list['test'], "g-", linewidth=1)
            plt.legend(['train', 'test'])
            plt.xlabel("epochs")
            plt.ylabel("Concordance index")
            plt.savefig(results_dir_dataset + "/convergence_%02d_lr=%.2E.png" % (j, lr), dpi=300)
            plt.close()  # one figure per candidate lr; close it to free memory

            code_test, loss_nn_sum, acc_test, pvalue_pred, c_index_pred, lbl_pred_all, OS_event_test, OS_test = \
                SALMON.test(model, datasets, 'test', length_of_data, batch_size, cuda, verbose)
            ci_list.append(c_index_pred)
            print("current concordance index: ", c_index_pred, "\n")
            logging.info("current concordance index: %.10f\n" % c_index_pred)

        optimal_lr = learning_rate_range[np.argmax(ci_list)]

        print("Optimal learning rate: %.4E, optimal c-index: %.10f" % (optimal_lr, np.max(ci_list)))
        logging.info("Optimal learning rate: %.4E, optimal c-index: %.10f" % (optimal_lr, np.max(ci_list)))

    # =============================================================================
    # # Training
    # =============================================================================

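        # retrain at the selected learning rate, then score both splits with the final model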
        model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = \
            SALMON.train(datasets, num_epochs, batch_size, optimal_lr, dropout_rate,
                         lambda_1, length_of_data, cuda, measure_while_training, verbose)
        code_train, loss_nn_sum, acc_train, pvalue_pred, c_index_pred, lbl_pred_all_train, OS_event_train, OS_train = \
            SALMON.test(model, datasets, 'train', length_of_data, batch_size, cuda, verbose)
        print("[Final] Apply model to training set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))
        logging.info("[Final] Apply model to training set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))

        code_test, loss_nn_sum, acc_test, pvalue_pred, c_index_pred, lbl_pred_all_test, OS_event_test, OS_test = \
            SALMON.test(model, datasets, 'test', length_of_data, batch_size, cuda, verbose)
        print("[Final] Apply model to testing set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))
        logging.info("[Final] Apply model to testing set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))

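        # persist the model plus per-patient codes, predicted hazards, and
        # survival labels for both splits, for downstream survival analysis
        # (e.g., with the lifelines log-rank test imported above)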
        with open(results_dir_dataset + '/model.pickle', 'wb') as handle:
            pickle.dump(model, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/c_index_list_by_epochs.pickle', 'wb') as handle:
            pickle.dump(c_index_list, handle, protocol=pickle.HIGHEST_PROTOCOL)

        with open(results_dir_dataset + '/hazard_ratios_lbl_pred_all_train.pickle', 'wb') as handle:
            pickle.dump(lbl_pred_all_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_event_train.pickle', 'wb') as handle:
            pickle.dump(OS_event_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_train.pickle', 'wb') as handle:
            pickle.dump(OS_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/code_train.pickle', 'wb') as handle:
            pickle.dump(code_train, handle, protocol=pickle.HIGHEST_PROTOCOL)

        with open(results_dir_dataset + '/hazard_ratios_lbl_pred_all_test.pickle', 'wb') as handle:
            pickle.dump(lbl_pred_all_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_event_test.pickle', 'wb') as handle:
            pickle.dump(OS_event_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_test.pickle', 'wb') as handle:
            pickle.dump(OS_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/code_test.pickle', 'wb') as handle:
            pickle.dump(code_test, handle, protocol=pickle.HIGHEST_PROTOCOL)

        epochs_list = range(num_epochs)
        plt.figure(figsize=(8, 4))
        plt.plot(epochs_list, c_index_list['train'], "b--", linewidth=1)
        plt.plot(epochs_list, c_index_list['test'], "g-", linewidth=1)
        plt.legend(['train', 'test'])
        plt.xlabel("epochs")
        plt.ylabel("Concordance index")
        plt.savefig(results_dir_dataset + "/convergence.png", dpi=300)
        plt.close()  # release the figure before the next fold