# Cross validation/MOLI Complete/Paclitaxel_cvSoftTripletClassifierNetv16_Script.py
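# Random-search cross-validation for Paclitaxel: a MOLI-style network with three
# omics-specific encoders (expression, mutation, CNA), an online triplet loss on the
# joint embedding, and a sigmoid classifier head, tuned by stratified 5-fold CV on
# GDSC data homogenized with the matched PDX Paclitaxel cohort.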

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pandas as pd
import math
import sklearn.preprocessing as sk
import seaborn as sns
from sklearn import metrics
from sklearn.feature_selection import VarianceThreshold
from sklearn.model_selection import train_test_split
from utils import AllTripletSelector, HardestNegativeTripletSelector, RandomNegativeTripletSelector, SemihardNegativeTripletSelector # Strategies for selecting triplets within a minibatch
from metrics import AverageNonzeroTripletsMetric
from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import random
from random import randint
from sklearn.model_selection import StratifiedKFold

save_results_to = '/home/hnoghabi/SoftClassifierTripNetv16/Paclitaxel/'
torch.manual_seed(42)

max_iter = 50

GDSCE = pd.read_csv("GDSC_exprs.Paclitaxel.eb_with.PDX_exprs.Paclitaxel.tsv",
                    sep = "\t", index_col=0, decimal = ",")
GDSCE = pd.DataFrame.transpose(GDSCE)

GDSCR = pd.read_csv("GDSC_response.Paclitaxel.tsv",
                    sep = "\t", index_col=0, decimal = ",")

PDXE = pd.read_csv("PDX_exprs.Paclitaxel.eb_with.GDSC_exprs.Paclitaxel.tsv",
                   sep = "\t", index_col=0, decimal = ",")
PDXE = pd.DataFrame.transpose(PDXE)

PDXM = pd.read_csv("PDX_mutations.Paclitaxel.tsv",
                   sep = "\t", index_col=0, decimal = ".")
PDXM = pd.DataFrame.transpose(PDXM)

PDXC = pd.read_csv("PDX_CNA.Paclitaxel.tsv",
                   sep = "\t", index_col=0, decimal = ".")
PDXC = PDXC.drop_duplicates(keep='last')  # drop_duplicates returns a new frame, so the result must be assigned
PDXC = pd.DataFrame.transpose(PDXC)
PDXC = PDXC.loc[:,~PDXC.columns.duplicated()]

GDSCM = pd.read_csv("GDSC_mutations.Paclitaxel.tsv",
                    sep = "\t", index_col=0, decimal = ".")
GDSCM = pd.DataFrame.transpose(GDSCM)

GDSCC = pd.read_csv("GDSC_CNA.Paclitaxel.tsv",
                    sep = "\t", index_col=0, decimal = ".")
GDSCC = GDSCC.drop_duplicates(keep='last')  # drop_duplicates returns a new frame, so the result must be assigned
GDSCC = pd.DataFrame.transpose(GDSCC)
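
# All omics matrices are read as genes x samples and transposed to samples x genes.
# The expression files (the GDSC/PDX 'eb_with' pair homogenized with each other) use
# ',' as the decimal separator, while the mutation and CNA tables use '.'.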

selector = VarianceThreshold(0.05)
selector.fit_transform(GDSCE)
GDSCE = GDSCE[GDSCE.columns[selector.get_support(indices=True)]]

PDXC = PDXC.fillna(0)
PDXC[PDXC != 0.0] = 1
PDXM = PDXM.fillna(0)
PDXM[PDXM != 0.0] = 1
GDSCM = GDSCM.fillna(0)
GDSCM[GDSCM != 0.0] = 1
GDSCC = GDSCC.fillna(0)
GDSCC[GDSCC != 0.0] = 1
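
# Expression features are filtered by variance (VarianceThreshold(0.05)); the mutation
# and CNA matrices are binarized so that any non-zero call becomes a 1.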

ls = GDSCE.columns.intersection(GDSCM.columns)
ls = ls.intersection(GDSCC.columns)
ls = ls.intersection(PDXE.columns)
ls = ls.intersection(PDXM.columns)
ls = ls.intersection(PDXC.columns)
ls2 = GDSCE.index.intersection(GDSCM.index)
ls2 = ls2.intersection(GDSCC.index)
ls3 = PDXE.index.intersection(PDXM.index)
ls3 = ls3.intersection(PDXC.index)
ls = pd.unique(ls)

PDXE = PDXE.loc[ls3,ls]
PDXM = PDXM.loc[ls3,ls]
PDXC = PDXC.loc[ls3,ls]
GDSCE = GDSCE.loc[ls2,ls]
GDSCM = GDSCM.loc[ls2,ls]
GDSCC = GDSCC.loc[ls2,ls]

GDSCR.loc[GDSCR.iloc[:,0] == 'R'] = 0
GDSCR.loc[GDSCR.iloc[:,0] == 'S'] = 1
GDSCR.columns = ['targets']
GDSCR = GDSCR.loc[ls2,:]

PDXR = pd.read_csv("PDX_response.Paclitaxel.tsv",
                   sep = "\t", index_col=0, decimal = ",")
PDXR.loc[PDXR.iloc[:,0] == 'R'] = 0
PDXR.loc[PDXR.iloc[:,0] == 'S'] = 1
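
# 'ls' is the set of genes shared by all six matrices, 'ls2' the GDSC cell lines and
# 'ls3' the PDX models present in every modality; all matrices are restricted to these.
# Drug response is encoded as R (resistant) -> 0 and S (sensitive) -> 1.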

ls_mb_size = [13, 36, 64]
ls_h_dim = [1023, 512, 256, 128, 64, 32, 16]
ls_marg = [0.5, 1, 1.5, 2, 2.5]
ls_lr = [0.5, 0.1, 0.05, 0.01, 0.001, 0.005, 0.0005, 0.0001, 0.00005, 0.00001]
ls_epoch = [20, 50, 10, 15, 30, 40, 60, 70, 80, 90, 100]
ls_rate = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
ls_wd = [0.01, 0.001, 0.1, 0.0001]
ls_lam = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]

Y = GDSCR['targets'].values

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)  # shuffle so that random_state actually controls the fold assignment

# Random hyperparameter search: each of the max_iter outer iterations draws one
# configuration from the lists above and evaluates it with stratified 5-fold CV.
for iters in range(max_iter):
    k = 0
    mbs = random.choice(ls_mb_size)
    hdm1 = random.choice(ls_h_dim)
    hdm2 = random.choice(ls_h_dim)
    hdm3 = random.choice(ls_h_dim)
    mrg = random.choice(ls_marg)
    lre = random.choice(ls_lr)
    lrm = random.choice(ls_lr)
    lrc = random.choice(ls_lr)
    lrCL = random.choice(ls_lr)
    epch = random.choice(ls_epoch)
    rate1 = random.choice(ls_rate)
    rate2 = random.choice(ls_rate)
    rate3 = random.choice(ls_rate)
    rate4 = random.choice(ls_rate)
    wd = random.choice(ls_wd)
    lam = random.choice(ls_lam)

    for train_index, test_index in skf.split(GDSCE.values, Y):
        k = k + 1
        X_trainE = GDSCE.values[train_index,:]
        X_testE = GDSCE.values[test_index,:]
        X_trainM = GDSCM.values[train_index,:]
        X_testM = GDSCM.values[test_index,:]
        X_trainC = GDSCC.values[train_index,:]
        X_testC = GDSCC.values[test_index,:]
        y_trainE = Y[train_index]
        y_testE = Y[test_index]

        scalerGDSC = sk.StandardScaler()
        scalerGDSC.fit(X_trainE)
        X_trainE = scalerGDSC.transform(X_trainE)
        X_testE = scalerGDSC.transform(X_testE)

        X_trainM = np.nan_to_num(X_trainM)
        X_trainC = np.nan_to_num(X_trainC)
        X_testM = np.nan_to_num(X_testM)
        X_testC = np.nan_to_num(X_testC)

        TX_testE = torch.FloatTensor(X_testE)
        TX_testM = torch.FloatTensor(X_testM)
        TX_testC = torch.FloatTensor(X_testC)
        ty_testE = torch.FloatTensor(y_testE.astype(int))

        #Train
        class_sample_count = np.array([len(np.where(y_trainE==t)[0]) for t in np.unique(y_trainE)])
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in y_trainE])

        samples_weight = torch.from_numpy(samples_weight)
        sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight), replacement=True)

        mb_size = mbs

        trainDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_trainE), torch.FloatTensor(X_trainM),
                                                      torch.FloatTensor(X_trainC), torch.FloatTensor(y_trainE.astype(int)))

        trainLoader = torch.utils.data.DataLoader(dataset = trainDataset, batch_size=mb_size, shuffle=False, num_workers=1, sampler = sampler)
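
        # Class imbalance is handled by the WeightedRandomSampler: each training sample is
        # drawn with probability inversely proportional to its class frequency, so the
        # sampler (rather than shuffle=True) determines the minibatch composition.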

        n_sampE, IE_dim = X_trainE.shape
        n_sampM, IM_dim = X_trainM.shape
        n_sampC, IC_dim = X_trainC.shape

        h_dim1 = hdm1
        h_dim2 = hdm2
        h_dim3 = hdm3
        Z_in = h_dim1 + h_dim2 + h_dim3
        marg = mrg
        lrE = lre
        lrM = lrm
        lrC = lrc
        epoch = epch

        costtr = []
        auctr = []
        costts = []
        aucts = []

        triplet_selector = RandomNegativeTripletSelector(marg)
        triplet_selector2 = AllTripletSelector()

        class AEE(nn.Module):
            def __init__(self):
                super(AEE, self).__init__()
                self.EnE = torch.nn.Sequential(
                    nn.Linear(IE_dim, h_dim1),
                    nn.BatchNorm1d(h_dim1),
                    nn.ReLU(),
                    nn.Dropout(rate1))
            def forward(self, x):
                output = self.EnE(x)
                return output

        class AEM(nn.Module):
            def __init__(self):
                super(AEM, self).__init__()
                self.EnM = torch.nn.Sequential(
                    nn.Linear(IM_dim, h_dim2),
                    nn.BatchNorm1d(h_dim2),
                    nn.ReLU(),
                    nn.Dropout(rate2))
            def forward(self, x):
                output = self.EnM(x)
                return output

        class AEC(nn.Module):
            def __init__(self):
                super(AEC, self).__init__()
                self.EnC = torch.nn.Sequential(
                    nn.Linear(IC_dim, h_dim3),  # CNA input width (equal to IM_dim here, since both matrices share the same gene set)
                    nn.BatchNorm1d(h_dim3),
                    nn.ReLU(),
                    nn.Dropout(rate3))
            def forward(self, x):
                output = self.EnC(x)
                return output
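
        # Each omics modality (expression, mutation, CNA) has its own single hidden-layer
        # encoder (Linear -> BatchNorm -> ReLU -> Dropout); the three embeddings are later
        # concatenated into a joint representation of width Z_in = h_dim1 + h_dim2 + h_dim3.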

        class OnlineTriplet(nn.Module):
            def __init__(self, marg, triplet_selector):
                super(OnlineTriplet, self).__init__()
                self.marg = marg
                self.triplet_selector = triplet_selector
            def forward(self, embeddings, target):
                triplets = self.triplet_selector.get_triplets(embeddings, target)
                return triplets

        class OnlineTestTriplet(nn.Module):
            def __init__(self, marg, triplet_selector):
                super(OnlineTestTriplet, self).__init__()
                self.marg = marg
                self.triplet_selector = triplet_selector
            def forward(self, embeddings, target):
                triplets = self.triplet_selector.get_triplets(embeddings, target)
                return triplets
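
        # Thin wrappers around the imported triplet-selection strategies: given a batch of
        # embeddings and labels they return the mined (anchor, positive, negative) index
        # triplets. TripSel2 (AllTripletSelector, i.e. all valid triplets) is the one used below.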

        class Classifier(nn.Module):
            def __init__(self):
                super(Classifier, self).__init__()
                self.FC = torch.nn.Sequential(
                    nn.Linear(Z_in, 1),
                    nn.Dropout(rate4),
                    nn.Sigmoid())
            def forward(self, x):
                return self.FC(x)

        torch.cuda.manual_seed_all(42)

        AutoencoderE = AEE()
        AutoencoderM = AEM()
        AutoencoderC = AEC()

        solverE = optim.Adagrad(AutoencoderE.parameters(), lr=lrE)
        solverM = optim.Adagrad(AutoencoderM.parameters(), lr=lrM)
        solverC = optim.Adagrad(AutoencoderC.parameters(), lr=lrC)

        trip_criterion = torch.nn.TripletMarginLoss(margin=marg, p=2)
        TripSel = OnlineTriplet(marg, triplet_selector)
        TripSel2 = OnlineTestTriplet(marg, triplet_selector2)

        Clas = Classifier()
        SolverClass = optim.Adagrad(Clas.parameters(), lr=lrCL, weight_decay = wd)
        C_loss = torch.nn.BCELoss()
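
        # Training objective: loss = lam * TripletMarginLoss over the mined triplets of the
        # joint embedding + BCELoss of the classifier output against the response label.
        # Each encoder and the classifier head has its own Adagrad optimizer.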
        for it in range(epoch):

            epoch_cost4 = 0
            epoch_cost3 = []
            num_minibatches = int(n_sampE / mb_size)
            flag = 0  # tracks whether any minibatch in this epoch was usable for training

            for i, (dataE, dataM, dataC, target) in enumerate(trainLoader):
                AutoencoderE.train()
                AutoencoderM.train()
                AutoencoderC.train()
                Clas.train()

                # Skip minibatches that contain only one class: neither triplets nor AUC are defined there.
                if torch.mean(target)!=0. and torch.mean(target)!=1.:
                    ZEX = AutoencoderE(dataE)
                    ZMX = AutoencoderM(dataM)
                    ZCX = AutoencoderC(dataC)

                    ZT = torch.cat((ZEX, ZMX, ZCX), 1)
                    ZT = F.normalize(ZT, p=2, dim=0)
                    Pred = Clas(ZT)

                    Triplets = TripSel2(ZT, target)
                    loss = lam * trip_criterion(ZT[Triplets[:,0],:], ZT[Triplets[:,1],:], ZT[Triplets[:,2],:]) + C_loss(Pred, target.view(-1,1))

                    y_true = target.view(-1,1)
                    y_pred = Pred
                    AUC = roc_auc_score(y_true.detach().numpy(), y_pred.detach().numpy())

                    solverE.zero_grad()
                    solverM.zero_grad()
                    solverC.zero_grad()
                    SolverClass.zero_grad()

                    loss.backward()

                    solverE.step()
                    solverM.step()
                    solverC.step()
                    SolverClass.step()

                    epoch_cost4 = epoch_cost4 + (loss / num_minibatches)
                    epoch_cost3.append(AUC)
                    flag = 1

            if flag == 1:
                costtr.append(epoch_cost4.item())  # store a plain float detached from the graph so it can be plotted below
                auctr.append(np.mean(epoch_cost3))
                print('Iter-{}; Total loss: {:.4}'.format(it, loss))

            with torch.no_grad():

                AutoencoderE.eval()
                AutoencoderM.eval()
                AutoencoderC.eval()
                Clas.eval()

                ZET = AutoencoderE(TX_testE)
                ZMT = AutoencoderM(TX_testM)
                ZCT = AutoencoderC(TX_testC)

                ZTT = torch.cat((ZET, ZMT, ZCT), 1)
                ZTT = F.normalize(ZTT, p=2, dim=0)
                PredT = Clas(ZTT)

                TripletsT = TripSel2(ZTT, ty_testE)
                lossT = lam * trip_criterion(ZTT[TripletsT[:,0],:], ZTT[TripletsT[:,1],:], ZTT[TripletsT[:,2],:]) + C_loss(PredT, ty_testE.view(-1,1))

                y_truet = ty_testE.view(-1,1)
                y_predt = PredT
                AUCt = roc_auc_score(y_truet.detach().numpy(), y_predt.detach().numpy())

                costts.append(lossT.item())  # plain float, consistent with costtr
                aucts.append(AUCt)
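
                # Optional, not part of the original script: average_precision_score is imported
                # above but never used. A minimal sketch for also tracking validation AUPRC
                # (same inputs as the AUC call) would be:
                # AUPRCt = average_precision_score(y_truet.detach().numpy(), y_predt.detach().numpy())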

        plt.plot(np.squeeze(costtr), '-r', np.squeeze(costts), '-b')
        plt.ylabel('Total cost')
        plt.xlabel('epochs')

        title = 'Cost Paclitaxel iter = {}, fold = {}, mb_size = {},  h_dim[1,2,3] = ({},{},{}), marg = {}, lr[E,M,C] = ({}, {}, {}), epoch = {}, rate[1,2,3,4] = ({},{},{},{}), wd = {}, lrCL = {}, lam = {}'.\
                      format(iters, k, mbs, hdm1, hdm2, hdm3, mrg, lre, lrm, lrc, epch, rate1, rate2, rate3, rate4, wd, lrCL, lam)

        plt.suptitle(title)
        plt.savefig(save_results_to + title + '.png', dpi = 150)
        plt.close()

        plt.plot(np.squeeze(auctr), '-r', np.squeeze(aucts), '-b')
        plt.ylabel('AUC')
        plt.xlabel('epochs')

        title = 'AUC Paclitaxel iter = {}, fold = {},  mb_size = {},  h_dim[1,2,3] = ({},{},{}), marg = {}, lr[E,M,C] = ({}, {}, {}), epoch = {}, rate[1,2,3,4] = ({},{},{},{}), wd = {}, lrCL = {}, lam = {}'.\
                      format(iters, k, mbs, hdm1, hdm2, hdm3, mrg, lre, lrm, lrc, epch, rate1, rate2, rate3, rate4, wd, lrCL, lam)

        plt.suptitle(title)
        plt.savefig(save_results_to + title + '.png', dpi = 150)
        plt.close()
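
        # Optional, not part of the original script: fold-level results only survive in the
        # figure titles above. A hedged sketch for also appending them to a CSV (the file
        # name 'cv_results.csv' is a hypothetical choice) could look like:
        # with open(save_results_to + 'cv_results.csv', 'a') as f:
        #     f.write('{},{},{},{}\n'.format(iters, k, aucts[-1], auctr[-1] if auctr else float('nan')))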