
--- /dev/null
+++ b/experiments/other cancer/main_LUAD.py
@@ -0,0 +1,247 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+@author: Zhi Huang
+"""
+
+import sys, os
+sys.path.append("/home/zhihuan/Documents/SALMON/model")
+import SALMON
+import argparse
+import logging
+import pickle
+import time
+import numpy as np
+import pandas as pd
+import torch
+from sklearn.model_selection import KFold
+from sklearn import preprocessing
+import matplotlib
+matplotlib.use('Agg')  # headless backend so figures can be saved without a display
+import matplotlib.pyplot as plt
+plt.ioff()
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset_dir', default='/home/zhihuan/Documents/SALMON/data/LUAD/multiomics_preprocessing_results/', help="datasets")
+    parser.add_argument('--num_epochs', type=int, default=100, help="Number of epochs to train for. Default: 100")
+    parser.add_argument('--measure_while_training', action='store_true', default=False, help='measure the concordance index on train/test after every epoch (slower)')
+    parser.add_argument('--batch_size', type=int, default=64, help="Number of samples per batch for train/test. Default: 64")
+    parser.add_argument('--dataset', type=int, default=7, help="feature subset to use, 1-7 (see the if/elif chain below). Default: 7")
+    parser.add_argument('--nocuda', action='store_true', default=False, help='disables CUDA training')
+    parser.add_argument('--verbose', default=1, type=int)
+    parser.add_argument('--results_dir', default='/home/zhihuan/Documents/SALMON/experiments/Results/LUAD', help="results dir")
+    return parser.parse_args()
+
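+# Example invocation (assuming the preprocessed LUAD data sit at the default --dataset_dir):
+#   python main_LUAD.py --dataset 7 --num_epochs 100 --batch_size 64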
+if __name__=='__main__':
+    torch.cuda.empty_cache()
+    args = parse_args()
+
+    # hyperparameters
+    num_epochs = args.num_epochs
+    batch_size = args.batch_size
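+    # learning-rate grid: 10 candidates from 1e-4 to ~5e-2, evenly spaced in log10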
+    learning_rate_range = 10**np.arange(-4,-1,0.3)
+    cuda = not args.nocuda
+    verbose = args.verbose
+    measure_while_training = True  # the convergence plots below need per-epoch c-index values
+    dropout_rate = 0
+    lambda_1 = 1e-6  # L1 regularization strength
+    
+    # load preprocessed multi-omics tables
+    tempdata = {}
+    tempdata['clinical'] = pd.read_csv(args.dataset_dir + 'clinical.csv', index_col = 0).reset_index(drop = True)
+    tempdata['mRNAseq_eigengene'] = pd.read_csv(args.dataset_dir + 'mRNAseq_eigengene_matrix.csv', index_col = 0).reset_index(drop = True)
+    tempdata['miRNAseq_eigengene'] = pd.read_csv(args.dataset_dir + 'miRNAseq_eigengene_matrix.csv', index_col = 0).reset_index(drop = True)
+    tempdata['TMB'] = pd.read_csv(args.dataset_dir + 'TMB.csv', index_col = 0).reset_index(drop = True)
+    tempdata['CNB'] = pd.read_csv(args.dataset_dir + 'CNB.csv', index_col = 0).reset_index(drop = True)
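+    # log2-transform copy number burden to compress its dynamic range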
+    tempdata['CNB']['log2_LENGTH_KB'] = np.log2(tempdata['CNB']['LENGTH_KB'].values + 1)
+    
+    print('0:FEMALE\t1:MALE\n0:Alive\t\t1:Dead')
+    tempdata['clinical']['gender'] = (tempdata['clinical']['gender'].values == 'MALE').astype(int)
+    tempdata['clinical']['vital_status'] = (tempdata['clinical']['vital_status'].values == 'Dead').astype(int)
+    
+    
+    data = {}
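+    # assemble the full feature matrix: mRNA/miRNA eigengenes, CNB, TMB, and two clinical covariates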
+    data['x'] = pd.concat((tempdata['mRNAseq_eigengene'],
+                           tempdata['miRNAseq_eigengene'],
+                           tempdata['CNB']['log2_LENGTH_KB'],
+                           tempdata['TMB']['All_TMB'],
+                           tempdata['clinical'][['gender','age_at_initial_pathologic_diagnosis']]),
+                          axis = 1).values.astype(np.double)
+    all_column_names = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
+                            ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
+                            ['CNB', 'TMB', 'GENDER', 'AGE']
+    print('perform min-max scaler on all input features')
+    scaler = preprocessing.MinMaxScaler()
+    scaler.fit(data['x'])
+    data['x'] = scaler.transform(data['x'])
+    
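+    # e: event indicator (1 = death observed), t: survival time in days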
+    data['e'] = tempdata['clinical']['vital_status'].values.astype(np.int32)
+    data['t'] = tempdata['clinical']['survival_days'].values.astype(np.double)
+    
+    if args.dataset == 1:
+        dataset_subset = "1_RNAseq"
+        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])]
+        
+    elif args.dataset == 2:
+        dataset_subset = "2_miRNAseq"
+        data['column_names'] = ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])]
+        
+    elif args.dataset == 3:
+        dataset_subset = "3_RNAseq+miRNAseq"
+        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
+                                ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])]
+    elif args.dataset == 4:
+        dataset_subset = "4_RNAseq+miRNAseq+cnb+tmb"
+        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
+                                ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
+                                ['CNB', 'TMB']
+    elif args.dataset == 5:
+        dataset_subset = "5_RNAseq+miRNAseq+clinical"
+        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
+                                ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
+                                ['GENDER', 'AGE']
+    elif args.dataset == 6:
+        dataset_subset = "6_cnb+tmb+clinical"
+        data['column_names'] = ['CNB', 'TMB', 'GENDER', 'AGE']
+        
+    elif args.dataset == 7:
+        dataset_subset = "7_RNAseq+miRNAseq+cnb+tmb+clinical"
+        data['column_names'] = ['mRNAseq_' + str(i+1) for i in range(tempdata['mRNAseq_eigengene'].shape[1])] + \
+                                ['miRNAseq_' + str(i+1) for i in range(tempdata['miRNAseq_eigengene'].shape[1])] + \
+                                ['CNB', 'TMB', 'GENDER', 'AGE']
+    else:
+        raise ValueError('--dataset must be an integer between 1 and 7')
+    print('subsetting data...')
+    data['x'] = data['x'][:, [i for i, c in enumerate(all_column_names) if c in data['column_names']]]
+
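+    # 5-fold CV split; fixed random_state keeps the folds reproducible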
+    kf = KFold(n_splits=5, shuffle=True, random_state=666)
+    datasets_5folds = {}
+    for ix, (train_index, test_index) in enumerate(kf.split(data['x']), start = 1):
+        datasets_5folds[ix] = {}
+        datasets_5folds[ix]['train'] = {}
+        datasets_5folds[ix]['train']['x'] = data['x'][train_index, :]
+        datasets_5folds[ix]['train']['e'] = data['e'][train_index]
+        datasets_5folds[ix]['train']['t'] = data['t'][train_index]
+        datasets_5folds[ix]['test'] = {}
+        datasets_5folds[ix]['test']['x'] = data['x'][test_index, :]
+        datasets_5folds[ix]['test']['e'] = data['e'][test_index]
+        datasets_5folds[ix]['test']['t'] = data['t'][test_index]
+
+    for i in range(1, len(datasets_5folds) + 1):
+        print("5 fold CV -- %d/5" % i)
+        
+        # dataset
+        TIMESTRING  = time.strftime("%Y%m%d-%H.%M.%S", time.localtime())
+        
+        results_dir_dataset = args.results_dir + '/' + dataset_subset + '/run_' + TIMESTRING + '_fold_' + str(i)
+        if not os.path.exists(results_dir_dataset):
+            os.makedirs(results_dir_dataset)
+            
+        # force=True (Python 3.8+) resets handlers so each fold writes to its own log file
+        logging.basicConfig(filename=results_dir_dataset+'/mainlog.log', level=logging.DEBUG, force=True)
+        logging.info("Arguments: %s" % args)
+        datasets = datasets_5folds[i]
+        
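+        # number of input features contributed by each data type (passed to SALMON.train/test)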
+        length_of_data = {}
+        length_of_data['mRNAseq'] = tempdata['mRNAseq_eigengene'].shape[1]
+        length_of_data['miRNAseq'] = tempdata['miRNAseq_eigengene'].shape[1]
+        length_of_data['CNB'] = 1
+        length_of_data['TMB'] = 1
+        length_of_data['clinical'] = 2
+        
+    # =============================================================================
+    # # Finding optimal learning rate w.r.t. concordance index
+    # =============================================================================
+        ci_list = []
+        for j, lr in enumerate(learning_rate_range):
+            print("[%d/%d] current lr: %.4E" %((j+1), len(learning_rate_range), lr))
+            logging.info("[%d/%d] current lr: %.4E" %((j+1), len(learning_rate_range), lr))
+            model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = \
+                 SALMON.train(datasets, num_epochs, batch_size, lr, dropout_rate,\
+                                         lambda_1, length_of_data, cuda, measure_while_training, verbose)
+        
+            epochs_list = range(num_epochs)
+            plt.figure(figsize=(8,4))
+            plt.plot(epochs_list, c_index_list['train'], "b--",linewidth=1)
+            plt.plot(epochs_list, c_index_list['test'], "g-",linewidth=1)
+            plt.legend(['train', 'test'])
+            plt.xlabel("epochs")
+            plt.ylabel("Concordance index")
+            plt.savefig(results_dir_dataset + "/convergence_%02d_lr=%.2E.png" % (j, lr),dpi=300)
+            plt.close()
+            code_test, loss_nn_sum, acc_test, pvalue_pred, c_index_pred, lbl_pred_all, OS_event_test, OS_test = \
+                SALMON.test(model, datasets, 'test', length_of_data, batch_size, cuda, verbose)
+            ci_list.append(c_index_pred)
+            print("current concordance index: ", c_index_pred,"\n")
+            logging.info("current concordance index: %.10f\n" % c_index_pred)
+            
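+        # pick the learning rate whose test-set c-index was highest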
+        optimal_lr = learning_rate_range[np.argmax(ci_list)]
+        
+        print("Optimal learning rate: %.4E, optimal c-index: %.10f" % (optimal_lr, np.max(ci_list)))
+        logging.info("Optimal learning rate: %.4E, optimal c-index: %.10f" % (optimal_lr, np.max(ci_list)))
+    
+    
+    # =============================================================================
+    # # Training 
+    # =============================================================================
+    
+        model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = \
+                 SALMON.train(datasets, num_epochs, batch_size, optimal_lr, dropout_rate,\
+                                         lambda_1, length_of_data, cuda, measure_while_training, verbose)
+        code_train, loss_nn_sum, acc_train, pvalue_pred, c_index_pred, lbl_pred_all_train, OS_event_train, OS_train = \
+            SALMON.test(model, datasets, 'train', length_of_data, batch_size, cuda, verbose)
+        print("[Final] Apply model to training set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))
+        logging.info("[Final] Apply model to training set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))
+    
+        code_test, loss_nn_sum, acc_test, pvalue_pred, c_index_pred, lbl_pred_all_test, OS_event_test, OS_test = \
+            SALMON.test(model, datasets, 'test', length_of_data, batch_size, cuda, verbose)
+        print("[Final] Apply model to testing set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))
+        logging.info("[Final] Apply model to testing set: c-index: %.10f, p-value: %.10e" % (c_index_pred, pvalue_pred))
+               
+        
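+        # persist the trained model and per-sample outputs for downstream analysis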
+        with open(results_dir_dataset + '/model.pickle', 'wb') as handle:
+            pickle.dump(model, handle, protocol=pickle.HIGHEST_PROTOCOL)
+        with open(results_dir_dataset + '/c_index_list_by_epochs.pickle', 'wb') as handle:
+            pickle.dump(c_index_list, handle, protocol=pickle.HIGHEST_PROTOCOL)
+            
+        with open(results_dir_dataset + '/hazard_ratios_lbl_pred_all_train.pickle', 'wb') as handle:
+            pickle.dump(lbl_pred_all_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
+        with open(results_dir_dataset + '/OS_event_train.pickle', 'wb') as handle:
+            pickle.dump(OS_event_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
+        with open(results_dir_dataset + '/OS_train.pickle', 'wb') as handle:
+            pickle.dump(OS_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
+        with open(results_dir_dataset + '/code_train.pickle', 'wb') as handle:
+            pickle.dump(code_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
+            
+        with open(results_dir_dataset + '/hazard_ratios_lbl_pred_all_test.pickle', 'wb') as handle:
+            pickle.dump(lbl_pred_all_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
+        with open(results_dir_dataset + '/OS_event_test.pickle', 'wb') as handle:
+            pickle.dump(OS_event_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
+        with open(results_dir_dataset + '/OS_test.pickle', 'wb') as handle:
+            pickle.dump(OS_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
+        with open(results_dir_dataset + '/code_test.pickle', 'wb') as handle:
+            pickle.dump(code_test, handle, protocol=pickle.HIGHEST_PROTOCOL)
+    
+    
+        epochs_list = range(num_epochs)
+        plt.figure(figsize=(8,4))
+        plt.plot(epochs_list, c_index_list['train'], "b--",linewidth=1)
+        plt.plot(epochs_list, c_index_list['test'], "g-",linewidth=1)
+        plt.legend(['train', 'test'])
+        plt.xlabel("epochs")
+        plt.ylabel("Concordance index")
+        plt.savefig(results_dir_dataset + "/convergence.png",dpi=300)
+        plt.close()
+
+