Switch to side-by-side view

--- a
+++ b/Cross validation/EGFR Pan-drug/EGFR_TripClassNetv1.py
@@ -0,0 +1,420 @@
+import torch 
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import pandas as pd
+import math
+import sklearn.preprocessing as sk
+import seaborn as sns
+from sklearn import metrics
+from sklearn.feature_selection import VarianceThreshold
+from sklearn.model_selection import train_test_split
+from utils import AllTripletSelector,HardestNegativeTripletSelector, RandomNegativeTripletSelector, SemihardNegativeTripletSelector # Strategies for selecting triplets within a minibatch
+from metrics import AverageNonzeroTripletsMetric
+from torch.utils.data.sampler import WeightedRandomSampler
+from sklearn.metrics import roc_auc_score
+from sklearn.metrics import average_precision_score
+import random
+from random import randint
+from sklearn.model_selection import StratifiedKFold
+
# Output directory for the per-fold diagnostic plots.
save_results_to = '/home/hnoghabi/EGFR/'

# Fix the CPU RNG so weight initialization is reproducible across runs.
torch.manual_seed(42)

# Number of random hyper-parameter configurations to evaluate.
max_iter = 50
+
# --- GDSC training data: expression (z-scores), mutations, CNA ---------------
# Expression uses "," as decimal separator; transpose puts samples on rows.
GDSCE = pd.read_csv("GDSC_exprs.z.EGFRi.tsv",
                    sep = "\t", index_col=0, decimal = ",")
GDSCE = pd.DataFrame.transpose(GDSCE)

GDSCM = pd.read_csv("GDSC_mutations.EGFRi.tsv",
                    sep = "\t", index_col=0, decimal = ".")
GDSCM = pd.DataFrame.transpose(GDSCM)
# Keep the first occurrence of each duplicated gene column.
GDSCM = GDSCM.loc[:,~GDSCM.columns.duplicated()]

GDSCC = pd.read_csv("GDSC_CNA.EGFRi.tsv",
                    sep = "\t", index_col=0, decimal = ".")
# BUG FIX: drop_duplicates returns a new frame; the original call discarded
# the result, so duplicate rows were silently kept.
GDSCC = GDSCC.drop_duplicates(keep='last')
GDSCC = pd.DataFrame.transpose(GDSCC)
GDSCC = GDSCC.loc[:,~GDSCC.columns.duplicated()]
+
# --- PDX validation cohorts (erlotinib and cetuximab) ------------------------
# Expression matrices were batch-corrected ("eb_with") against GDSC.
PDXEerlo = pd.read_csv("PDX_exprs.Erlotinib.eb_with.GDSC_exprs.Erlotinib.tsv",
                   sep = "\t", index_col=0, decimal = ",")
PDXEerlo = pd.DataFrame.transpose(PDXEerlo)
PDXMerlo = pd.read_csv("PDX_mutations.Erlotinib.tsv",
                   sep = "\t", index_col=0, decimal = ",")
PDXMerlo = pd.DataFrame.transpose(PDXMerlo)
PDXCerlo = pd.read_csv("PDX_CNA.Erlotinib.tsv",
                   sep = "\t", index_col=0, decimal = ",")
# BUG FIX: drop_duplicates returns a new frame; the original discarded it.
PDXCerlo = PDXCerlo.drop_duplicates(keep='last')
PDXCerlo = pd.DataFrame.transpose(PDXCerlo)
PDXCerlo = PDXCerlo.loc[:,~PDXCerlo.columns.duplicated()]

PDXEcet = pd.read_csv("PDX_exprs.Cetuximab.eb_with.GDSC_exprs.Cetuximab.tsv",
                   sep = "\t", index_col=0, decimal = ",")
PDXEcet = pd.DataFrame.transpose(PDXEcet)
PDXMcet = pd.read_csv("PDX_mutations.Cetuximab.tsv",
                   sep = "\t", index_col=0, decimal = ",")
PDXMcet = pd.DataFrame.transpose(PDXMcet)
PDXCcet = pd.read_csv("PDX_CNA.Cetuximab.tsv",
                   sep = "\t", index_col=0, decimal = ",")
# BUG FIX: same no-op drop_duplicates as above.
PDXCcet = PDXCcet.drop_duplicates(keep='last')
PDXCcet = pd.DataFrame.transpose(PDXCcet)
PDXCcet = PDXCcet.loc[:,~PDXCcet.columns.duplicated()]
+
# Drop near-constant expression features (variance threshold 0.05).
selector = VarianceThreshold(0.05)
# IDIOM FIX: fit_transform's return value was discarded; fitting is all that
# is needed since the surviving columns are re-selected via get_support below.
selector.fit(GDSCE)
GDSCE = GDSCE[GDSCE.columns[selector.get_support(indices=True)]]

# Binarize mutation and CNA matrices: any non-zero, non-NaN entry -> 1.
GDSCM = GDSCM.fillna(0)
GDSCM[GDSCM != 0.0] = 1
GDSCC = GDSCC.fillna(0)
GDSCC[GDSCC != 0.0] = 1

# ls  : genes common to every omics layer in every cohort.
# ls2 : GDSC samples present in all three GDSC layers (recomputed later for
#       the per-drug expanded matrices; kept here for parity).
# ls3 : erlotinib PDX samples present in all three PDX layers.
# ls4 : cetuximab PDX samples present in all three PDX layers.
ls = GDSCE.columns.intersection(GDSCM.columns)
ls = ls.intersection(GDSCC.columns)
ls = ls.intersection(PDXEerlo.columns)
ls = ls.intersection(PDXMerlo.columns)
ls = ls.intersection(PDXCerlo.columns)
ls = ls.intersection(PDXEcet.columns)
ls = ls.intersection(PDXMcet.columns)
ls = ls.intersection(PDXCcet.columns)
ls2 = GDSCE.index.intersection(GDSCM.index)
ls2 = ls2.intersection(GDSCC.index)
ls3 = PDXEerlo.index.intersection(PDXMerlo.index)
ls3 = ls3.intersection(PDXCerlo.index)
ls4 = PDXEcet.index.intersection(PDXMcet.index)
ls4 = ls4.intersection(PDXCcet.index)
ls = pd.unique(ls)

# Restrict every matrix to the shared gene set (and cohort samples).
PDXEerlo = PDXEerlo.loc[ls3,ls]
PDXMerlo = PDXMerlo.loc[ls3,ls]
PDXCerlo = PDXCerlo.loc[ls3,ls]
PDXEcet = PDXEcet.loc[ls4,ls]
PDXMcet = PDXMcet.loc[ls4,ls]
PDXCcet = PDXCcet.loc[ls4,ls]
GDSCE = GDSCE.loc[:,ls]
GDSCM = GDSCM.loc[:,ls]
GDSCC = GDSCC.loc[:,ls]
+
# --- GDSC drug response ------------------------------------------------------
GDSCR = pd.read_csv("GDSC_response.EGFRi.tsv",
                    sep = "\t", index_col=0, decimal = ",")

GDSCR.rename(mapper = str, axis = 'index', inplace = True)

# Encode response labels: R(esistant) -> 0, S(ensitive) -> 1.
d = {"R":0,"S":1}
GDSCR["response"] = GDSCR.loc[:,"response"].apply(lambda x: d[x])

# A cell line can be screened against several EGFR inhibitors; suffix every
# sample name with its drug so that (sample, drug) pairs become unique rows.
responses = GDSCR
drugs = set(responses["drug"].values)
exprs_z = GDSCE
cna = GDSCC
mut = GDSCM
expression_zscores = []
CNA = []
mutations = []
for drug in drugs:
    samples = responses.loc[responses["drug"]==drug,:].index.values
    e_z = exprs_z.loc[samples,:]
    c = cna.loc[samples,:]
    # BUG FIX: this lookup was duplicated in the original; one copy removed.
    m = mut.loc[samples,:]
    e_z.rename(lambda x : str(x)+"_"+drug, axis = "index", inplace=True)
    c.rename(lambda x : str(x)+"_"+drug, axis = "index", inplace=True)
    m.rename(lambda x : str(x)+"_"+drug, axis = "index", inplace=True)
    expression_zscores.append(e_z)
    CNA.append(c)
    mutations.append(m)
responses.index = responses.index.values +"_"+responses["drug"].values
GDSCEv2 = pd.concat(expression_zscores, axis =0 )
GDSCCv2 = pd.concat(CNA, axis =0 )
GDSCMv2 = pd.concat(mutations, axis =0 )
GDSCRv2 = responses

# Align all three omics layers and the responses on the shared sample index.
ls2 = GDSCEv2.index.intersection(GDSCMv2.index)
ls2 = ls2.intersection(GDSCCv2.index)
GDSCEv2 = GDSCEv2.loc[ls2,:]
GDSCMv2 = GDSCMv2.loc[ls2,:]
GDSCCv2 = GDSCCv2.loc[ls2,:]
GDSCRv2 = GDSCRv2.loc[ls2,:]
+
# Random-search grids for the hyper-parameters.
ls_mb_size = [8, 16, 32, 64]
ls_h_dim = [1023, 512, 256, 128, 64, 32, 16]
ls_marg = [0.5, 1, 1.5, 2, 2.5]
ls_lr = [0.5, 0.1, 0.05, 0.01, 0.001, 0.005, 0.0005, 0.0001,0.00005, 0.00001]
ls_epoch = [20, 50, 10, 15, 30, 40, 60, 70, 80, 90, 100]
ls_rate = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
ls_wd = [0.01, 0.001, 0.1, 0.0001]
ls_lam = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]

Y = GDSCRv2['response'].values

# BUG FIX: random_state only takes effect when shuffle=True; scikit-learn
# >= 0.24 raises a ValueError for random_state combined with shuffle=False.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
+
for iters in range(max_iter):
    k = 0  # fold counter for this configuration
    # Draw one random hyper-parameter configuration. The tuple is evaluated
    # left-to-right, so the RNG draw order matches the original script.
    (mbs, hdm1, hdm2, hdm3, mrg, lre, lrm, lrc, lrCL, epch,
     rate1, rate2, rate3, rate4, wd, lam) = (
        random.choice(ls_mb_size),   # minibatch size
        random.choice(ls_h_dim),     # expression encoder width
        random.choice(ls_h_dim),     # mutation encoder width
        random.choice(ls_h_dim),     # CNA encoder width
        random.choice(ls_marg),      # triplet margin
        random.choice(ls_lr),        # lr: expression encoder
        random.choice(ls_lr),        # lr: mutation encoder
        random.choice(ls_lr),        # lr: CNA encoder
        random.choice(ls_lr),        # lr: classifier head
        random.choice(ls_epoch),     # number of epochs
        random.choice(ls_rate),      # dropout rates (4x)
        random.choice(ls_rate),
        random.choice(ls_rate),
        random.choice(ls_rate),
        random.choice(ls_wd),        # classifier weight decay
        random.choice(ls_lam),       # triplet-loss weight
    )
+
+    for train_index, test_index in skf.split(GDSCEv2.values, Y):
+        k = k + 1
+        X_trainE = GDSCEv2.values[train_index,:]
+        X_testE =  GDSCEv2.values[test_index,:]
+        X_trainM = GDSCMv2.values[train_index,:]
+        X_testM = GDSCMv2.values[test_index,:]
+        X_trainC = GDSCCv2.values[train_index,:]
+        X_testC = GDSCMv2.values[test_index,:]
+        y_trainE = Y[train_index]
+        y_testE = Y[test_index]
+        
+        scalerGDSC = sk.StandardScaler()
+        scalerGDSC.fit(X_trainE)
+        X_trainE = scalerGDSC.transform(X_trainE)
+        X_testE = scalerGDSC.transform(X_testE)
+
+        X_trainM = np.nan_to_num(X_trainM)
+        X_trainC = np.nan_to_num(X_trainC)
+        X_testM = np.nan_to_num(X_testM)
+        X_testC = np.nan_to_num(X_testC)
+        
+        TX_testE = torch.FloatTensor(X_testE)
+        TX_testM = torch.FloatTensor(X_testM)
+        TX_testC = torch.FloatTensor(X_testC)
+        ty_testE = torch.FloatTensor(y_testE.astype(int))
+        
+        #Train
+        class_sample_count = np.array([len(np.where(y_trainE==t)[0]) for t in np.unique(y_trainE)])
+        weight = 1. / class_sample_count
+        samples_weight = np.array([weight[t] for t in y_trainE])
+
+        samples_weight = torch.from_numpy(samples_weight)
+        sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight), replacement=True)
+
+        mb_size = mbs
+
+        trainDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_trainE), torch.FloatTensor(X_trainM), 
+                                                      torch.FloatTensor(X_trainC), torch.FloatTensor(y_trainE.astype(int)))
+
+        trainLoader = torch.utils.data.DataLoader(dataset = trainDataset, batch_size=mb_size, shuffle=False, num_workers=1, sampler = sampler)
+
+        n_sampE, IE_dim = X_trainE.shape
+        n_sampM, IM_dim = X_trainM.shape
+        n_sampC, IC_dim = X_trainC.shape
+
+        h_dim1 = hdm1
+        h_dim2 = hdm2
+        h_dim3 = hdm3        
+        Z_in = h_dim1 + h_dim2 + h_dim3
+        marg = mrg
+        lrE = lre
+        lrM = lrm
+        lrC = lrc
+        epoch = epch
+
+        costtr = []
+        auctr = []
+        costts = []
+        aucts = []
+
+        triplet_selector = RandomNegativeTripletSelector(marg)
+        triplet_selector2 = AllTripletSelector()
+
+        class AEE(nn.Module):
+            def __init__(self):
+                super(AEE, self).__init__()
+                self.EnE = torch.nn.Sequential(
+                    nn.Linear(IE_dim, h_dim1),
+                    nn.BatchNorm1d(h_dim1),
+                    nn.ReLU(),
+                    nn.Dropout(rate1))
+            def forward(self, x):
+                output = self.EnE(x)
+                return output
+
+        class AEM(nn.Module):
+            def __init__(self):
+                super(AEM, self).__init__()
+                self.EnM = torch.nn.Sequential(
+                    nn.Linear(IM_dim, h_dim2),
+                    nn.BatchNorm1d(h_dim2),
+                    nn.ReLU(),
+                    nn.Dropout(rate2))
+            def forward(self, x):
+                output = self.EnM(x)
+                return output    
+
+
+        class AEC(nn.Module):
+            def __init__(self):
+                super(AEC, self).__init__()
+                self.EnC = torch.nn.Sequential(
+                    nn.Linear(IM_dim, h_dim3),
+                    nn.BatchNorm1d(h_dim3),
+                    nn.ReLU(),
+                    nn.Dropout(rate3))
+            def forward(self, x):
+                output = self.EnC(x)
+                return output    
+
+        class OnlineTriplet(nn.Module):
+            def __init__(self, marg, triplet_selector):
+                super(OnlineTriplet, self).__init__()
+                self.marg = marg
+                self.triplet_selector = triplet_selector
+            def forward(self, embeddings, target):
+                triplets = self.triplet_selector.get_triplets(embeddings, target)
+                return triplets
+
+        class OnlineTestTriplet(nn.Module):
+            def __init__(self, marg, triplet_selector):
+                super(OnlineTestTriplet, self).__init__()
+                self.marg = marg
+                self.triplet_selector = triplet_selector
+            def forward(self, embeddings, target):
+                triplets = self.triplet_selector.get_triplets(embeddings, target)
+                return triplets    
+
+        class Classifier(nn.Module):
+            def __init__(self):
+                super(Classifier, self).__init__()
+                self.FC = torch.nn.Sequential(
+                    nn.Linear(Z_in, 1),
+                    nn.Dropout(rate4),
+                    nn.Sigmoid())
+            def forward(self, x):
+                return self.FC(x)
+
+        torch.cuda.manual_seed_all(42)
+
+        AutoencoderE = AEE()
+        AutoencoderM = AEM()
+        AutoencoderC = AEC()
+
+        solverE = optim.Adagrad(AutoencoderE.parameters(), lr=lrE)
+        solverM = optim.Adagrad(AutoencoderM.parameters(), lr=lrM)
+        solverC = optim.Adagrad(AutoencoderC.parameters(), lr=lrC)
+
+        trip_criterion = torch.nn.TripletMarginLoss(margin=marg, p=2)
+        TripSel = OnlineTriplet(marg, triplet_selector)
+        TripSel2 = OnlineTestTriplet(marg, triplet_selector2)
+
+        Clas = Classifier()
+        SolverClass = optim.Adagrad(Clas.parameters(), lr=lrCL, weight_decay = wd)
+        C_loss = torch.nn.BCELoss()
+
+        for it in range(epoch):
+
+            epoch_cost4 = 0
+            epoch_cost3 = []
+            num_minibatches = int(n_sampE / mb_size) 
+
+            for i, (dataE, dataM, dataC, target) in enumerate(trainLoader):
+                flag = 0
+                AutoencoderE.train()
+                AutoencoderM.train()
+                AutoencoderC.train()
+                Clas.train()
+
+                if torch.mean(target)!=0. and torch.mean(target)!=1.: 
+                    ZEX = AutoencoderE(dataE)
+                    ZMX = AutoencoderM(dataM)
+                    ZCX = AutoencoderC(dataC)
+
+                    ZT = torch.cat((ZEX, ZMX, ZCX), 1)
+                    ZT = F.normalize(ZT, p=2, dim=0)
+                    Pred = Clas(ZT)
+
+                    Triplets = TripSel2(ZT, target)
+                    loss = lam * trip_criterion(ZT[Triplets[:,0],:],ZT[Triplets[:,1],:],ZT[Triplets[:,2],:]) + C_loss(Pred,target.view(-1,1))     
+
+                    y_true = target.view(-1,1)
+                    y_pred = Pred
+                    AUC = roc_auc_score(y_true.detach().numpy(),y_pred.detach().numpy()) 
+
+                    solverE.zero_grad()
+                    solverM.zero_grad()
+                    solverC.zero_grad()
+                    SolverClass.zero_grad()
+
+                    loss.backward()
+
+                    solverE.step()
+                    solverM.step()
+                    solverC.step()
+                    SolverClass.step()
+
+                    epoch_cost4 = epoch_cost4 + (loss / num_minibatches)
+                    epoch_cost3.append(AUC)
+                    flag = 1
+
+            if flag == 1:
+                costtr.append(torch.mean(epoch_cost4))
+                auctr.append(np.mean(epoch_cost3))
+                print('Iter-{}; Total loss: {:.4}'.format(it, loss))
+
+            with torch.no_grad():
+
+                AutoencoderE.eval()
+                AutoencoderM.eval()
+                AutoencoderC.eval()
+                Clas.eval()
+
+                ZET = AutoencoderE(TX_testE)
+                ZMT = AutoencoderM(TX_testM)
+                ZCT = AutoencoderC(TX_testC)
+
+                ZTT = torch.cat((ZET, ZMT, ZCT), 1)
+                ZTT = F.normalize(ZTT, p=2, dim=0)
+                PredT = Clas(ZTT)
+
+                TripletsT = TripSel2(ZTT, ty_testE)
+                lossT = lam * trip_criterion(ZTT[TripletsT[:,0],:], ZTT[TripletsT[:,1],:], ZTT[TripletsT[:,2],:]) + C_loss(PredT,ty_testE.view(-1,1))
+
+                y_truet = ty_testE.view(-1,1)
+                y_predt = PredT
+                AUCt = roc_auc_score(y_truet.detach().numpy(),y_predt.detach().numpy())        
+
+                costts.append(lossT)
+                aucts.append(AUCt)
+
+        plt.plot(np.squeeze(costtr), '-r',np.squeeze(costts), '-b')
+        plt.ylabel('Total cost')
+        plt.xlabel('iterations (per tens)')
+
+        title = 'Cost Cetuximab iter = {}, fold = {}, mb_size = {},  h_dim[1,2,3] = ({},{},{}), marg = {}, lr[E,M,C] = ({}, {}, {}), epoch = {}, rate[1,2,3,4] = ({},{},{},{}), wd = {}, lrCL = {}, lam = {}'.\
+                      format(iters, k, mbs, hdm1, hdm2, hdm3, mrg, lre, lrm, lrc, epch, rate1, rate2, rate3, rate4, wd, lrCL, lam)
+
+        plt.suptitle(title)
+        plt.savefig(save_results_to + title + '.png', dpi = 150)
+        plt.close()
+
+        plt.plot(np.squeeze(auctr), '-r',np.squeeze(aucts), '-b')
+        plt.ylabel('AUC')
+        plt.xlabel('iterations (per tens)')
+
+        title = 'AUC Cetuximab iter = {}, fold = {}, mb_size = {},  h_dim[1,2,3] = ({},{},{}), marg = {}, lr[E,M,C] = ({}, {}, {}), epoch = {}, rate[1,2,3,4] = ({},{},{},{}), wd = {}, lrCL = {}, lam = {}'.\
+                      format(iters, k, mbs, hdm1, hdm2, hdm3, mrg, lre, lrm, lrc, epch, rate1, rate2, rate3, rate4, wd, lrCL, lam)        
+
+        plt.suptitle(title)
+        plt.savefig(save_results_to + title + '.png', dpi = 150)
+        plt.close()