|
a |
|
b/experiments/other cancer/main_LUAD.py |
|
|
1 |
#!/usr/bin/env python3 |
|
|
2 |
# -*- coding: utf-8 -*- |
|
|
3 |
""" |
|
|
4 |
@author: Zhi Huang |
|
|
5 |
""" |
|
|
6 |
|
|
|
7 |
import sys, os |
|
|
8 |
sys.path.append("/home/zhihuan/Documents/SALMON/model") |
|
|
9 |
import SALMON |
|
|
10 |
import pandas as pd |
|
|
11 |
import argparse |
|
|
12 |
import torch |
|
|
13 |
import torch.nn as nn |
|
|
14 |
import torch.nn.functional as F |
|
|
15 |
import torch.optim as optim |
|
|
16 |
from torch.utils.data import DataLoader |
|
|
17 |
from torchvision import transforms |
|
|
18 |
from torch.autograd import Variable |
|
|
19 |
from collections import Counter |
|
|
20 |
import pandas as pd |
|
|
21 |
import math |
|
|
22 |
import random |
|
|
23 |
from imblearn.over_sampling import RandomOverSampler |
|
|
24 |
from lifelines.statistics import logrank_test |
|
|
25 |
import json |
|
|
26 |
import tables |
|
|
27 |
import logging |
|
|
28 |
import csv |
|
|
29 |
import numpy as np |
|
|
30 |
import optunity |
|
|
31 |
import pickle |
|
|
32 |
import time |
|
|
33 |
from sklearn.model_selection import KFold |
|
|
34 |
from sklearn import preprocessing |
|
|
35 |
import matplotlib |
|
|
36 |
matplotlib.use('Agg') |
|
|
37 |
import matplotlib.pyplot as plt |
|
|
38 |
plt.ioff() |
|
|
39 |
|
|
|
40 |
|
|
|
41 |
def parse_args(argv=None):
    """Parse command-line arguments for the LUAD SALMON experiment.

    Parameters
    ----------
    argv : list of str, optional
        Argument list to parse. Defaults to ``None``, in which case argparse
        reads ``sys.argv[1:]`` — fully backward compatible with the original
        zero-argument call.

    Returns
    -------
    argparse.Namespace
        Parsed arguments (dataset paths, epochs, batch size, dataset subset
        id, CUDA flag, verbosity, results directory).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', default='/home/zhihuan/Documents/SALMON/data/LUAD/multiomics_preprocessing_results/', help="datasets")
    parser.add_argument('--num_epochs', type=int, default=100, help="Number of epochs to train for. Default: 100")
    # Fixed help text: store_true ENABLES the measurement (it is off by
    # default); the old help said "disables", the opposite of the behavior.
    parser.add_argument('--measure_while_training', action='store_true', default=False, help='enables measuring c-index while training (makes program slower)')
    # Fixed help text: the actual default is 64, not 256 as previously claimed.
    parser.add_argument('--batch_size', type=int, default=64, help="Number of batches to train/test for. Default: 64")
    # Dataset subset selector (1-7); see the elif chain in __main__ for the
    # meaning of each id (7 = all omics + clinical).
    parser.add_argument('--dataset', type=int, default=7)
    parser.add_argument('--nocuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--verbose', default=1, type=int)
    parser.add_argument('--results_dir', default='/home/zhihuan/Documents/SALMON/experiments/Results/LUAD', help="results dir")
    return parser.parse_args(argv)
|
|
52 |
|
|
|
53 |
if __name__=='__main__':
    # Free any cached GPU memory left over from a previous run before training.
    torch.cuda.empty_cache()
    args = parse_args()

    # ----- Hyperparameters ------------------------------------------------
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    # Learning-rate grid searched below: 10^-4 .. 10^-1.3 in 0.3-decade steps.
    learning_rate_range = 10**np.arange(-4,-1,0.3)
    # NOTE(review): these three hard-coded values OVERRIDE the CLI flags
    # (--nocuda, --verbose, --measure_while_training are effectively ignored).
    cuda = True
    verbose = 0
    measure_while_training = True
    dropout_rate = 0
    lambda_1 = 1e-6 # L1 regularization strength passed to SALMON.train

    # ----- Load multi-omics inputs for the 5-fold CV ----------------------
    # Each CSV is assumed to be row-aligned by patient across files (the
    # concat below relies on it) — TODO confirm against the preprocessing step.
    tempdata = {}
    tempdata['clinical'] = pd.read_csv(args.dataset_dir + 'clinical.csv', index_col = 0).reset_index(drop = True)
    tempdata['mRNAseq_eigengene'] = pd.read_csv(args.dataset_dir + 'mRNAseq_eigengene_matrix.csv', index_col = 0).reset_index(drop = True)
    tempdata['miRNAseq_eigengene'] = pd.read_csv(args.dataset_dir + 'miRNAseq_eigengene_matrix.csv', index_col = 0).reset_index(drop = True)
    tempdata['TMB'] = pd.read_csv(args.dataset_dir + 'TMB.csv', index_col = 0).reset_index(drop = True)
    tempdata['CNB'] = pd.read_csv(args.dataset_dir + 'CNB.csv', index_col = 0).reset_index(drop = True)
    # log2(1 + length) compresses the copy-number-burden length scale.
    tempdata['CNB']['log2_LENGTH_KB'] = np.log2(tempdata['CNB']['LENGTH_KB'].values + 1)

    # Binary-encode categorical clinical fields (encoding printed for the log).
    print('0:MALE\t\t1:FEMALE\n0:Alive\t\t1:Dead')
    tempdata['clinical']['gender'] = (tempdata['clinical']['gender'].values == 'MALE').astype(int)
    tempdata['clinical']['vital_status'] = (tempdata['clinical']['vital_status'].values == 'Dead').astype(int)


    # ----- Assemble the full feature matrix -------------------------------
    # Column order must match all_column_names below: mRNAseq eigengenes,
    # miRNAseq eigengenes, CNB, TMB, gender, age.
    data = {}
    data['x'] = pd.concat((tempdata['mRNAseq_eigengene'], tempdata['miRNAseq_eigengene'], tempdata['CNB']['log2_LENGTH_KB'], tempdata['TMB']['All_TMB'], tempdata['clinical'][['gender','age_at_initial_pathologic_diagnosis']]), axis = 1).values.astype(np.double)
    all_column_names = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
                       ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
                       ['CNB', 'TMB', 'GENDER', 'AGE']
    # NOTE(review): scaler is fit on ALL samples before the CV split, so a
    # small amount of test-set information leaks into the scaling.
    print('perform min-max scaler on all input features')
    scaler = preprocessing.MinMaxScaler()
    scaler.fit(data['x'])
    data['x'] = scaler.transform(data['x'])

    # Survival labels: e = event indicator (1 = Dead), t = survival time (days).
    data['e'] = tempdata['clinical']['vital_status'].values.astype(np.int32)
    data['t'] = tempdata['clinical']['survival_days'].values.astype(np.double)

    # ----- Select the feature subset for this experiment (--dataset 1..7) --
    if args.dataset == 1:
        dataset_subset = "1_RNAseq"
        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])]

    elif args.dataset == 2:
        dataset_subset = "2_miRNAseq"
        data['column_names'] = ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])]

    elif args.dataset == 3:
        dataset_subset = "3_RNAseq+miRNAseq"
        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
                               ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])]
    elif args.dataset == 4:
        dataset_subset = "4_RNAseq+miRNAseq+cnb+tmb"
        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
                               ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
                               ['CNB', 'TMB']
    elif args.dataset == 5:
        dataset_subset = "5_RNAseq+miRNAseq+clinical"
        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
                               ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
                               ['GENDER', 'AGE']
    elif args.dataset == 6:
        dataset_subset = "6_cnb+tmb+clinical"
        data['column_names'] = ['CNB', 'TMB', 'GENDER', 'AGE']

    elif args.dataset == 7:
        dataset_subset = "7_RNAseq+miRNAseq+cnb+tmb+clinical"
        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
                               ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
                               ['CNB', 'TMB', 'GENDER', 'AGE']
    # Keep only the columns named in data['column_names'], preserving their
    # original order in all_column_names.
    print('subsetting data...')
    data['x'] = data['x'][:, [i for i, c in enumerate(all_column_names) if c in data['column_names']]]
|
|
127 |
|
|
|
128 |
kf = KFold(n_splits=5, shuffle=True, random_state=666) |
|
|
129 |
datasets_5folds = {} |
|
|
130 |
for ix, (train_index, test_index) in enumerate(kf.split(data['x']), start = 1): |
|
|
131 |
datasets_5folds[ix] = {} |
|
|
132 |
datasets_5folds[ix]['train'] = {} |
|
|
133 |
datasets_5folds[ix]['train']['x'] = data['x'][train_index, :] |
|
|
134 |
datasets_5folds[ix]['train']['e'] = data['e'][train_index] |
|
|
135 |
datasets_5folds[ix]['train']['t'] = data['t'][train_index] |
|
|
136 |
datasets_5folds[ix]['test'] = {} |
|
|
137 |
datasets_5folds[ix]['test']['x'] = data['x'][train_index, :] |
|
|
138 |
datasets_5folds[ix]['test']['e'] = data['e'][train_index] |
|
|
139 |
datasets_5folds[ix]['test']['t'] = data['t'][train_index] |
|
|
140 |
|
|
|
141 |
    # ----- Per-fold training and evaluation -------------------------------
    for i in range(1, len(datasets_5folds) + 1):
        print("5 fold CV -- %d/5" % i)

        # Timestamped results directory so repeated runs never collide.
        TIMESTRING = time.strftime("%Y%m%d-%H.%M.%S", time.localtime())

        results_dir_dataset = args.results_dir + '/' + dataset_subset + '/run_' + TIMESTRING + '_fold_' + str(i)
        if not os.path.exists(results_dir_dataset):
            os.makedirs(results_dir_dataset)

        # NOTE(review): logging.basicConfig is a no-op once a handler exists,
        # so only fold 1's mainlog.log receives log records; later folds log
        # into the first fold's file.
        logging.basicConfig(filename=results_dir_dataset+'/mainlog.log',level=logging.DEBUG)
        # print("Arguments:",args)
        # logging.info("Arguments: %s" % args)
        datasets = datasets_5folds[i]

        # Per-modality feature counts; SALMON.train presumably uses these to
        # partition the input vector into omics-specific subnetworks — TODO
        # confirm against model/SALMON.py.
        length_of_data = {}
        length_of_data['mRNAseq'] = tempdata['mRNAseq_eigengene'].shape[1]
        length_of_data['miRNAseq'] = tempdata['miRNAseq_eigengene'].shape[1]
        length_of_data['CNB'] = 1
        length_of_data['TMB'] = 1
        length_of_data['clinical'] = 2

        # =============================================================================
        # # Finding optimal learning rate w.r.t. concordance index
        # =============================================================================
        # Train one model per candidate learning rate and keep the test-set
        # concordance index of each; the argmax wins.
        ci_list = []
        for j, lr in enumerate(learning_rate_range):
            print("[%d/%d] current lr: %.4E" %((j+1), len(learning_rate_range), lr))
            logging.info("[%d/%d] current lr: %.4E" %((j+1), len(learning_rate_range), lr))
            model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = \
                SALMON.train(datasets, num_epochs, batch_size, lr, dropout_rate,\
                             lambda_1, length_of_data, cuda, measure_while_training, verbose)

            # Convergence plot: train vs. test c-index over epochs for this lr.
            epochs_list = range(num_epochs)
            plt.figure(figsize=(8,4))
            plt.plot(epochs_list, c_index_list['train'], "b--",linewidth=1)
            plt.plot(epochs_list, c_index_list['test'], "g-",linewidth=1)
            plt.legend(['train', 'test'])
            plt.xlabel("epochs")
            plt.ylabel("Concordance index")
            plt.savefig(results_dir_dataset + "/convergence_%02d_lr=%.2E.png" % (j, lr),dpi=300)
            plt.close()
            # Score this lr on the held-out fold.
            code_test, loss_nn_sum, acc_test, pvalue_pred, c_index_pred, lbl_pred_all, OS_event_test, OS_test = \
                SALMON.test(model, datasets, 'test', length_of_data, batch_size, cuda, verbose)
            ci_list.append(c_index_pred)
            print("current concordance index: ", c_index_pred,"\n")
            logging.info("current concordance index: %.10f\n" % c_index_pred)

        optimal_lr = learning_rate_range[np.argmax(ci_list)]

        print("Optimal learning rate: %.4E, optimal c-index: %.10f" % (optimal_lr, np.max(ci_list)))
        logging.info("Optimal learning rate: %.4E, optimal c-index: %.10f" % (optimal_lr, np.max(ci_list)))


        # =============================================================================
        # # Training
        # =============================================================================
        # Retrain from scratch at the selected learning rate, then evaluate
        # the final model on both splits.
        model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = \
            SALMON.train(datasets, num_epochs, batch_size, optimal_lr, dropout_rate,\
                         lambda_1, length_of_data, cuda, measure_while_training, verbose)
        code_train, loss_nn_sum, acc_train, pvalue_pred, c_index_pred, lbl_pred_all_train, OS_event_train, OS_train = \
            SALMON.test(model, datasets, 'train', length_of_data, batch_size, cuda, verbose)
        print("[Final] Apply model to training set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))
        logging.info("[Final] Apply model to training set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))

        code_test, loss_nn_sum, acc_test, pvalue_pred, c_index_pred, lbl_pred_all_test, OS_event_test, OS_test = \
            SALMON.test(model, datasets, 'test', length_of_data, batch_size, cuda, verbose)
        print("[Final] Apply model to testing set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))
        logging.info("[Final] Apply model to testing set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))


        # Persist the model, per-epoch c-indices, and per-split predictions
        # (hazard ratios, survival labels/times, learned codes) for later
        # analysis. NOTE(review): pickling a full torch model ties these files
        # to the current class definitions/torch version.
        with open(results_dir_dataset + '/model.pickle', 'wb') as handle:
            pickle.dump(model, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/c_index_list_by_epochs.pickle', 'wb') as handle:
            pickle.dump(c_index_list, handle, protocol=pickle.HIGHEST_PROTOCOL)

        with open(results_dir_dataset + '/hazard_ratios_lbl_pred_all_train.pickle', 'wb') as handle:
            pickle.dump(lbl_pred_all_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_event_train.pickle', 'wb') as handle:
            pickle.dump(OS_event_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_train.pickle', 'wb') as handle:
            pickle.dump(OS_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/code_train.pickle', 'wb') as handle:
            pickle.dump(code_train, handle, protocol=pickle.HIGHEST_PROTOCOL)

        with open(results_dir_dataset + '/hazard_ratios_lbl_pred_all_test.pickle', 'wb') as handle:
            pickle.dump(lbl_pred_all_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_event_test.pickle', 'wb') as handle:
            pickle.dump(OS_event_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/OS_test.pickle', 'wb') as handle:
            pickle.dump(OS_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(results_dir_dataset + '/code_test.pickle', 'wb') as handle:
            pickle.dump(code_test, handle, protocol=pickle.HIGHEST_PROTOCOL)


        # Final convergence plot for the model trained at the optimal lr.
        epochs_list = range(num_epochs)
        plt.figure(figsize=(8,4))
        plt.plot(epochs_list, c_index_list['train'], "b--",linewidth=1)
        plt.plot(epochs_list, c_index_list['test'], "g-",linewidth=1)
        plt.legend(['train', 'test'])
        plt.xlabel("epochs")
        plt.ylabel("Concordance index")
        plt.savefig(results_dir_dataset + "/convergence.png",dpi=300)
        plt.close()