
Cross validation/MOLI Complete/GemcitabineTCGA_cvSoftTripletClassifierNetv15.1_Script.py
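# Random-search cross-validation for the MOLI-style soft triplet classifier network
# predicting Gemcitabine response. GDSC expression, mutation and CNA matrices are used
# for 5-fold training/validation; the matching TCGA tables are loaded so that features
# can be restricted to the gene set shared by both cohorts.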
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pandas as pd
import math
import sklearn.preprocessing as sk
import seaborn as sns
from sklearn import metrics
from sklearn.feature_selection import VarianceThreshold
from sklearn.model_selection import train_test_split
from utils import AllTripletSelector, HardestNegativeTripletSelector, RandomNegativeTripletSelector, SemihardNegativeTripletSelector # Strategies for selecting triplets within a minibatch
from metrics import AverageNonzeroTripletsMetric
from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import random
from random import randint
from sklearn.model_selection import StratifiedKFold

save_results_to = '/home/hnoghabi/SoftClassifierTripNetv15.1/GemcitabineTCGA/'
seed = 42
torch.manual_seed(seed)

max_iter = 50

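# Load expression, mutation and CNA matrices for GDSC and TCGA
# (each table is transposed to samples x genes; duplicated gene columns are dropped)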
GDSCE = pd.read_csv("GDSC_exprs.Gemcitabine.eb_with.TCGA_exprs.Gemcitabine.tsv", 
                    sep = "\t", index_col=0, decimal = ",")
GDSCE = pd.DataFrame.transpose(GDSCE)

TCGAE = pd.read_csv("TCGA_exprs.Gemcitabine.eb_with.GDSC_exprs.Gemcitabine.tsv", 
                    sep = "\t", index_col=0, decimal = ",")
TCGAE = pd.DataFrame.transpose(TCGAE)

TCGAM = pd.read_csv("TCGA_mutations.Gemcitabine.tsv", 
                    sep = "\t", index_col=0, decimal = ".")
TCGAM = pd.DataFrame.transpose(TCGAM)
TCGAM = TCGAM.loc[:,~TCGAM.columns.duplicated()]

TCGAC = pd.read_csv("TCGA_CNA.Gemcitabine.tsv", 
                    sep = "\t", index_col=0, decimal = ".")
TCGAC = pd.DataFrame.transpose(TCGAC)
TCGAC = TCGAC.loc[:,~TCGAC.columns.duplicated()]

GDSCM = pd.read_csv("GDSC_mutations.Gemcitabine.tsv", 
                    sep = "\t", index_col=0, decimal = ".")
GDSCM = pd.DataFrame.transpose(GDSCM)
GDSCM = GDSCM.loc[:,~GDSCM.columns.duplicated()]

GDSCC = pd.read_csv("GDSC_CNA.Gemcitabine.tsv", 
                    sep = "\t", index_col=0, decimal = ".")
GDSCC = GDSCC.drop_duplicates(keep='last')  # drop_duplicates returns a copy, so assign it back
GDSCC = pd.DataFrame.transpose(GDSCC)
GDSCC = GDSCC.loc[:,~GDSCC.columns.duplicated()]

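# Remove low-variance expression features, then binarise mutation and CNA calls
# (any non-zero entry becomes 1)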
selector = VarianceThreshold(0.05)
selector.fit_transform(GDSCE)
GDSCE = GDSCE[GDSCE.columns[selector.get_support(indices=True)]]

TCGAC = TCGAC.fillna(0)
TCGAC[TCGAC != 0.0] = 1
TCGAM = TCGAM.fillna(0)
TCGAM[TCGAM != 0.0] = 1
GDSCM = GDSCM.fillna(0)
GDSCM[GDSCM != 0.0] = 1
GDSCC = GDSCC.fillna(0)
GDSCC[GDSCC != 0.0] = 1

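# Restrict all matrices to the genes shared across data types and cohorts,
# and to the samples present in every data type within each cohort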
ls = GDSCE.columns.intersection(GDSCM.columns)
ls = ls.intersection(GDSCC.columns)
ls = ls.intersection(TCGAE.columns)
ls = ls.intersection(TCGAM.columns)
ls = ls.intersection(TCGAC.columns)
ls2 = GDSCE.index.intersection(GDSCM.index)
ls2 = ls2.intersection(GDSCC.index)
ls3 = TCGAE.index.intersection(TCGAM.index)
ls3 = ls3.intersection(TCGAC.index)
ls = pd.unique(ls)

TCGAE = TCGAE.loc[ls3,ls]
TCGAM = TCGAM.loc[ls3,ls]
TCGAC = TCGAC.loc[ls3,ls]
GDSCE = GDSCE.loc[ls2,ls]
GDSCM = GDSCM.loc[ls2,ls]
GDSCC = GDSCC.loc[ls2,ls]

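# Load drug-response tables and encode the labels numerically (R -> 0, S -> 1)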
GDSCR = pd.read_csv("GDSC_response.Gemcitabine.tsv", 
                    sep = "\t", index_col=0, decimal = ",")
TCGAR = pd.read_csv("TCGA_response.Gemcitabine.tsv", 
                    sep = "\t", index_col=0, decimal = ",")

GDSCR.rename(mapper = str, axis = 'index', inplace = True)
GDSCR = GDSCR.loc[ls2,:]
#GDSCR.loc[GDSCR.iloc[:,0] == 'R','response'] = 0
#GDSCR.loc[GDSCR.iloc[:,0] == 'S','response'] = 1

TCGAR = TCGAR.loc[ls3,:]
#TCGAR.loc[TCGAR.iloc[:,1] == 'R','response'] = 0
#TCGAR.loc[TCGAR.iloc[:,1] == 'S','response'] = 1

d = {"R":0,"S":1}
GDSCR["response"] = GDSCR.loc[:,"response"].apply(lambda x: d[x])
TCGAR["response"] = TCGAR.loc[:,"response"].apply(lambda x: d[x])

Y = GDSCR['response'].values
#y_test = TCGAR['response'].values

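# Hyper-parameter search space for the random search
# (dropout rates are fixed to 0.5 in the loop below; ls_rate is kept but unused)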
ls_mb_size = [13, 30, 64]
ls_h_dim = [64, 32, 16]
ls_marg = [1, 1.5, 2]
ls_lr = [0.05, 0.01, 0.001, 0.005, 0.0005, 0.0001]
ls_epoch = [20, 50, 10, 15, 30, 40, 60, 70, 80, 90, 100]
ls_rate = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
ls_wd = [0.01, 0.001, 0.1, 0.0001]
ls_lam = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)  # random_state only takes effect when shuffle=True

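# Outer loop: draw one random hyper-parameter configuration per iteration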
for iters in range(max_iter):
    k = 0
    mbs = random.choice(ls_mb_size)
    hdm1 = random.choice(ls_h_dim)
    hdm2 = hdm1
    hdm3 = hdm1
    mrg = random.choice(ls_marg)
    lre = random.choice(ls_lr)
    lrm = random.choice(ls_lr)
    lrc = random.choice(ls_lr)
    lrCL = random.choice(ls_lr)
    epch = random.choice(ls_epoch)
    rate1 = 0.5
    rate2 = 0.5
    rate3 = 0.5
    rate4 = 0.5
    wd = random.choice(ls_wd)
    lam = random.choice(ls_lam)

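    # Inner loop: evaluate the drawn configuration with 5-fold stratified CV on GDSC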
    for train_index, test_index in skf.split(GDSCE.values, Y):
        k = k + 1
        X_trainE = GDSCE.values[train_index,:]
        X_testE = GDSCE.values[test_index,:]
        X_trainM = GDSCM.values[train_index,:]
        X_testM = GDSCM.values[test_index,:]
        X_trainC = GDSCC.values[train_index,:]
        X_testC = GDSCC.values[test_index,:]
        y_trainE = Y[train_index]
        y_testE = Y[test_index]

        # Standardise expression with statistics estimated on the training fold only
        scalerGDSC = sk.StandardScaler()
        scalerGDSC.fit(X_trainE)
        X_trainE = scalerGDSC.transform(X_trainE)
        X_testE = scalerGDSC.transform(X_testE)

        X_trainM = np.nan_to_num(X_trainM)
        X_trainC = np.nan_to_num(X_trainC)
        X_testM = np.nan_to_num(X_testM)
        X_testC = np.nan_to_num(X_testC)

        TX_testE = torch.FloatTensor(X_testE)
        TX_testM = torch.FloatTensor(X_testM)
        TX_testC = torch.FloatTensor(X_testC)
        ty_testE = torch.FloatTensor(y_testE.astype(int))

        # Train: draw class-balanced minibatches with a weighted random sampler
        # (each sample is weighted inversely to its class frequency)
        class_sample_count = np.array([len(np.where(y_trainE==t)[0]) for t in np.unique(y_trainE)])
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in y_trainE])

        samples_weight = torch.from_numpy(samples_weight)
        sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight), replacement=True)

        mb_size = mbs

        trainDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_trainE), torch.FloatTensor(X_trainM), 
                                                      torch.FloatTensor(X_trainC), torch.FloatTensor(y_trainE.astype(int)))

        trainLoader = torch.utils.data.DataLoader(dataset = trainDataset, batch_size=mb_size, shuffle=False, num_workers=1, sampler = sampler)

        n_sampE, IE_dim = X_trainE.shape
        n_sampM, IM_dim = X_trainM.shape
        n_sampC, IC_dim = X_trainC.shape

        h_dim1 = hdm1
        h_dim2 = hdm2
        h_dim3 = hdm3
        Z_in = h_dim1 + h_dim2 + h_dim3
        marg = mrg
        lrE = lre
        lrM = lrm
        lrC = lrc
        epoch = epch

        costtr = []
        auctr = []
        costts = []
        aucts = []

        triplet_selector = RandomNegativeTripletSelector(marg)
        triplet_selector2 = AllTripletSelector()

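        # Network components: one feed-forward encoder per data type (expression,
        # mutation, CNA), triplet-selector wrappers, and a sigmoid classifier applied
        # to the concatenated embedding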
        class AEE(nn.Module):
            def __init__(self):
                super(AEE, self).__init__()
                self.EnE = torch.nn.Sequential(
                    nn.Linear(IE_dim, h_dim1),
                    nn.BatchNorm1d(h_dim1),
                    nn.ReLU(),
                    nn.Dropout(rate1))
            def forward(self, x):
                output = self.EnE(x)
                return output

        class AEM(nn.Module):
            def __init__(self):
                super(AEM, self).__init__()
                self.EnM = torch.nn.Sequential(
                    nn.Linear(IM_dim, h_dim2),
                    nn.BatchNorm1d(h_dim2),
                    nn.ReLU(),
                    nn.Dropout(rate2))
            def forward(self, x):
                output = self.EnM(x)
                return output

        class AEC(nn.Module):
            def __init__(self):
                super(AEC, self).__init__()
                self.EnC = torch.nn.Sequential(
                    nn.Linear(IC_dim, h_dim3),
                    nn.BatchNorm1d(h_dim3),
                    nn.ReLU(),
                    nn.Dropout(rate3))
            def forward(self, x):
                output = self.EnC(x)
                return output

        class OnlineTriplet(nn.Module):
            def __init__(self, marg, triplet_selector):
                super(OnlineTriplet, self).__init__()
                self.marg = marg
                self.triplet_selector = triplet_selector
            def forward(self, embeddings, target):
                triplets = self.triplet_selector.get_triplets(embeddings, target)
                return triplets

        class OnlineTestTriplet(nn.Module):
            def __init__(self, marg, triplet_selector):
                super(OnlineTestTriplet, self).__init__()
                self.marg = marg
                self.triplet_selector = triplet_selector
            def forward(self, embeddings, target):
                triplets = self.triplet_selector.get_triplets(embeddings, target)
                return triplets

        class Classifier(nn.Module):
            def __init__(self):
                super(Classifier, self).__init__()
                self.FC = torch.nn.Sequential(
                    nn.Linear(Z_in, 1),
                    nn.Dropout(rate4),
                    nn.Sigmoid())
            def forward(self, x):
                return self.FC(x)

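        # Instantiate encoders and classifier; Adagrad optimisers, triplet-margin loss
        # and binary cross-entropy loss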
        torch.cuda.manual_seed_all(42)

        AutoencoderE = AEE()
        AutoencoderM = AEM()
        AutoencoderC = AEC()

        solverE = optim.Adagrad(AutoencoderE.parameters(), lr=lrE)
        solverM = optim.Adagrad(AutoencoderM.parameters(), lr=lrM)
        solverC = optim.Adagrad(AutoencoderC.parameters(), lr=lrC)

        trip_criterion = torch.nn.TripletMarginLoss(margin=marg, p=2)
        TripSel = OnlineTriplet(marg, triplet_selector)
        TripSel2 = OnlineTestTriplet(marg, triplet_selector2)

        Clas = Classifier()
        SolverClass = optim.Adagrad(Clas.parameters(), lr=lrCL, weight_decay = wd)
        C_loss = torch.nn.BCELoss()

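        # Train for the chosen number of epochs, evaluating on the held-out fold after each epoch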
        for it in range(epoch):

            epoch_cost4 = 0
            epoch_cost3 = []
            num_minibatches = int(n_sampE / mb_size)
            flag = 0  # set to 1 once at least one minibatch containing both classes has been trained on

            for i, (dataE, dataM, dataC, target) in enumerate(trainLoader):
                AutoencoderE.train()
                AutoencoderM.train()
                AutoencoderC.train()
                Clas.train()

                # Skip minibatches that contain only one class (no triplets, AUC undefined)
                if torch.mean(target)!=0. and torch.mean(target)!=1.:
                    ZEX = AutoencoderE(dataE)
                    ZMX = AutoencoderM(dataM)
                    ZCX = AutoencoderC(dataC)

                    ZT = torch.cat((ZEX, ZMX, ZCX), 1)
                    ZT = F.normalize(ZT, p=2, dim=0)
                    Pred = Clas(ZT)

                    Triplets = TripSel2(ZT, target)
                    loss = lam * trip_criterion(ZT[Triplets[:,0],:], ZT[Triplets[:,1],:], ZT[Triplets[:,2],:]) + C_loss(Pred, target.view(-1,1))

                    y_true = target.view(-1,1)
                    y_pred = Pred
                    AUC = roc_auc_score(y_true.detach().numpy(), y_pred.detach().numpy())

                    solverE.zero_grad()
                    solverM.zero_grad()
                    solverC.zero_grad()
                    SolverClass.zero_grad()

                    loss.backward()

                    solverE.step()
                    solverM.step()
                    solverC.step()
                    SolverClass.step()

                    epoch_cost4 = epoch_cost4 + (loss / num_minibatches)
                    epoch_cost3.append(AUC)
                    flag = 1

            if flag == 1:
                costtr.append(torch.mean(epoch_cost4).item())
                auctr.append(np.mean(epoch_cost3))
                print('Iter-{}; Total loss: {:.4}'.format(it, loss))

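            # Per-epoch evaluation on the held-out fold (gradients disabled)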
            with torch.no_grad():

                AutoencoderE.eval()
                AutoencoderM.eval()
                AutoencoderC.eval()
                Clas.eval()

                ZET = AutoencoderE(TX_testE)
                ZMT = AutoencoderM(TX_testM)
                ZCT = AutoencoderC(TX_testC)

                ZTT = torch.cat((ZET, ZMT, ZCT), 1)
                ZTT = F.normalize(ZTT, p=2, dim=0)
                PredT = Clas(ZTT)

                TripletsT = TripSel2(ZTT, ty_testE)
                lossT = lam * trip_criterion(ZTT[TripletsT[:,0],:], ZTT[TripletsT[:,1],:], ZTT[TripletsT[:,2],:]) + C_loss(PredT, ty_testE.view(-1,1))

                y_truet = ty_testE.view(-1,1)
                y_predt = PredT
                AUCt = roc_auc_score(y_truet.detach().numpy(), y_predt.detach().numpy())

                costts.append(lossT.item())
                aucts.append(AUCt)

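        # Plot training (red) vs. validation (blue) loss and AUC curves for this fold
        # and save them under a title that encodes the hyper-parameter configuration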
        plt.plot(np.squeeze(costtr), '-r', np.squeeze(costts), '-b')
        plt.ylabel('Total cost')
        plt.xlabel('epoch')

        title = 'Cost GemcitabineT iter = {}, fold = {}, mb_size = {},  h_dim[1,2,3] = ({},{},{}), marg = {}, lr[E,M,C] = ({}, {}, {}), epoch = {}, rate[1,2,3,4] = ({},{},{},{}), wd = {}, lrCL = {}, lam = {}'.\
                      format(iters, k, mbs, hdm1, hdm2, hdm3, mrg, lre, lrm, lrc, epch, rate1, rate2, rate3, rate4, wd, lrCL, lam)

        plt.suptitle(title)
        plt.savefig(save_results_to + title + '.png', dpi = 150)
        plt.close()

        plt.plot(np.squeeze(auctr), '-r', np.squeeze(aucts), '-b')
        plt.ylabel('AUC')
        plt.xlabel('epoch')

        title = 'AUC GemcitabineT iter = {}, fold = {}, mb_size = {},  h_dim[1,2,3] = ({},{},{}), marg = {}, lr[E,M,C] = ({}, {}, {}), epoch = {}, rate[1,2,3,4] = ({},{},{},{}), wd = {}, lrCL = {}, lam = {}'.\
                      format(iters, k, mbs, hdm1, hdm2, hdm3, mrg, lre, lrm, lrc, epch, rate1, rate2, rate3, rate4, wd, lrCL, lam)

        plt.suptitle(title)
        plt.savefig(save_results_to + title + '.png', dpi = 150)
        plt.close()