# File: Cross validation/MOLI Complete/Cetuximab_cvSoftTripletClassifierNetv16_Script.py
"""Cross-validated random hyper-parameter search for the Cetuximab model.

Three omics inputs (expression, mutation, copy number) are each encoded by a
one-layer network, concatenated, and trained jointly with a triplet-margin
loss and a binary classifier. Every sampled configuration is evaluated with
5-fold stratified cross-validation on the GDSC data, and the per-epoch
train/test cost and AUC curves are saved as PNG files.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pandas as pd
import math
import sklearn.preprocessing as sk
import seaborn as sns
from sklearn import metrics
from sklearn.feature_selection import VarianceThreshold
from sklearn.model_selection import train_test_split
from utils import AllTripletSelector,HardestNegativeTripletSelector, RandomNegativeTripletSelector, SemihardNegativeTripletSelector # Strategies for selecting triplets within a minibatch
from metrics import AverageNonzeroTripletsMetric
from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import random
from random import randint
from sklearn.model_selection import StratifiedKFold

# Directory (with trailing slash) that receives the cost/AUC plots.
save_results_to = '/home/hnoghabi/SoftClassifierTripNetv16/Cetuximab/'
torch.manual_seed(42)
# NOTE(review): the ``random`` module is not seeded, so the hyper-parameter
# configurations drawn below differ from run to run even though torch is seeded.

# Number of random hyper-parameter configurations to evaluate.
max_iter = 100
def _read_tsv(path, decimal, transpose=True):
    """Read a tab-separated table whose first column is the row index.

    Args:
        path: file to read.
        decimal: decimal separator used in the file ("," or ".").
        transpose: when True the table is transposed after loading, so the
            file's columns become rows. Downstream code indexes rows by
            sample (presumably the files store genes as rows).

    Returns:
        The loaded ``pandas.DataFrame``.
    """
    frame = pd.read_csv(path, sep="\t", index_col=0, decimal=decimal)
    return pd.DataFrame.transpose(frame) if transpose else frame


def _binarize(frame):
    """Return a copy of *frame* with NaN -> 0 and every non-zero entry -> 1."""
    frame = frame.fillna(0)
    frame[frame != 0.0] = 1
    return frame


GDSCE = _read_tsv("GDSC_exprs.Cetuximab.eb_with.PDX_exprs.Cetuximab.tsv", decimal=",")
GDSCR = _read_tsv("GDSC_response.Cetuximab.tsv", decimal=",", transpose=False)
PDXE = _read_tsv("PDX_exprs.Cetuximab.eb_with.GDSC_exprs.Cetuximab.tsv", decimal=",")
PDXM = _read_tsv("PDX_mutations.Cetuximab.tsv", decimal=".")
# NOTE(review): no de-duplication is applied to the CNA tables. A bare
# ``frame.drop_duplicates(...)`` call without assignment is a no-op, and
# ``drop_duplicates`` compares row *values*, not gene labels. To drop repeated
# gene IDs use ``frame[~frame.index.duplicated(keep='last')]`` before transposing.
PDXC = _read_tsv("PDX_CNA.Cetuximab.tsv", decimal=".")
GDSCM = _read_tsv("GDSC_mutations.Cetuximab.tsv", decimal=".")
GDSCC = _read_tsv("GDSC_CNA.Cetuximab.tsv", decimal=".")

# Keep only expression features whose variance across GDSC rows exceeds 0.05.
selector = VarianceThreshold(0.05)
selector.fit(GDSCE)
GDSCE = GDSCE[GDSCE.columns[selector.get_support(indices=True)]]

# Mutation and CNA tables become 0/1 indicators (missing entries count as 0).
PDXC = _binarize(PDXC)
PDXM = _binarize(PDXM)
GDSCM = _binarize(GDSCM)
GDSCC = _binarize(GDSCC)

# Features (columns) present in all six tables.
ls = GDSCE.columns
for _other in (GDSCM, GDSCC, PDXE, PDXM, PDXC):
    ls = ls.intersection(_other.columns)
# Rows present in all three GDSC tables / all three PDX tables.
ls2 = GDSCE.index.intersection(GDSCM.index).intersection(GDSCC.index)
ls3 = PDXE.index.intersection(PDXM.index).intersection(PDXC.index)
ls = pd.unique(ls)

PDXE = PDXE.loc[ls3, ls]
PDXM = PDXM.loc[ls3, ls]
PDXC = PDXC.loc[ls3, ls]
GDSCE = GDSCE.loc[ls2, ls]
GDSCM = GDSCM.loc[ls2, ls]
GDSCC = GDSCC.loc[ls2, ls]

# Encode the response column: 'R' -> 0, 'S' -> 1 (presumably resistant /
# sensitive). Renaming to a single 'targets' column fails unless GDSCR has
# exactly one column.
GDSCR.loc[GDSCR.iloc[:, 0] == 'R'] = 0
GDSCR.loc[GDSCR.iloc[:, 0] == 'S'] = 1
GDSCR.columns = ['targets']
GDSCR = GDSCR.loc[ls2, :]
# Candidate values for the random hyper-parameter search.
ls_mb_size = [14, 30, 64]                 # minibatch sizes (not sampled: the search uses a fixed size)
ls_h_dim = [1024, 512, 256, 128, 64]      # hidden width of each omics encoder
ls_marg = [0.5, 1, 1.5, 2, 2.5]           # triplet-loss margins
ls_lr = [0.0005, 0.0001, 0.005, 0.001]    # Adagrad learning rates (encoders and classifier)
ls_epoch = [20, 50, 10, 15, 30, 40, 60, 70, 80, 90, 100]  # training epochs per fold
ls_rate = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]  # dropout probabilities
ls_wd = [0.01, 0.001, 0.1, 0.0001]        # classifier weight decay
ls_lam = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]   # weight of the triplet term in the joint loss
# Label vector, one entry per row of GDSCR.
Y = GDSCR['targets'].values

# 5-fold stratified CV without shuffling, so the folds are deterministic.
# ``random_state`` is omitted: it has no effect when ``shuffle=False`` and
# scikit-learn >= 0.24 raises a ValueError if it is set in that case.
skf = StratifiedKFold(n_splits=5)
class AEE(nn.Module):
    """Expression encoder: Linear -> BatchNorm1d -> ReLU -> Dropout.

    Args:
        in_dim: number of input features.
        h_dim: width of the encoded representation.
        rate: dropout probability.
    """

    def __init__(self, in_dim, h_dim, rate):
        super(AEE, self).__init__()
        self.EnE = torch.nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.BatchNorm1d(h_dim),
            nn.ReLU(),
            nn.Dropout(rate))

    def forward(self, x):
        return self.EnE(x)


class AEM(nn.Module):
    """Mutation encoder: Linear -> BatchNorm1d -> ReLU -> Dropout (args as AEE)."""

    def __init__(self, in_dim, h_dim, rate):
        super(AEM, self).__init__()
        self.EnM = torch.nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.BatchNorm1d(h_dim),
            nn.ReLU(),
            nn.Dropout(rate))

    def forward(self, x):
        return self.EnM(x)


class AEC(nn.Module):
    """Copy-number encoder: Linear -> BatchNorm1d -> ReLU -> Dropout (args as AEE)."""

    def __init__(self, in_dim, h_dim, rate):
        super(AEC, self).__init__()
        self.EnC = torch.nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.BatchNorm1d(h_dim),
            nn.ReLU(),
            nn.Dropout(rate))

    def forward(self, x):
        return self.EnC(x)


class OnlineTriplet(nn.Module):
    """Return the index triplets chosen by ``triplet_selector`` for a batch.

    Columns 0/1/2 of the result are used as anchor/positive/negative indices.
    ``marg`` is only stored; the margin is applied by the triplet loss itself.
    """

    def __init__(self, marg, triplet_selector):
        super(OnlineTriplet, self).__init__()
        self.marg = marg
        self.triplet_selector = triplet_selector

    def forward(self, embeddings, target):
        return self.triplet_selector.get_triplets(embeddings, target)


class Classifier(nn.Module):
    """Linear -> Dropout -> Sigmoid head giving one probability per sample."""

    def __init__(self, z_in, rate):
        super(Classifier, self).__init__()
        # NOTE(review): dropout acts on the single logit before the sigmoid;
        # confirm this is intended.
        self.FC = torch.nn.Sequential(
            nn.Linear(z_in, 1),
            nn.Dropout(rate),
            nn.Sigmoid())

    def forward(self, x):
        return self.FC(x)


def _embed(encE, encM, encC, xE, xM, xC):
    """Encode each omics input, concatenate along features, and L2-normalise."""
    Z = torch.cat((encE(xE), encM(xM), encC(xC)), 1)
    # NOTE(review): dim=0 normalises each feature column across the batch, so a
    # sample's embedding depends on the rest of its batch (and test scores on
    # the whole test fold). Per-sample normalisation would be dim=1; kept as-is
    # to preserve the model's behaviour.
    return F.normalize(Z, p=2, dim=0)


def _joint_loss(Z, pred, target, trip_sel, trip_criterion, bce, lam):
    """Return ``lam * triplet_loss(Z) + bce(pred, target)``."""
    triplets = trip_sel(Z, target)
    trip = trip_criterion(Z[triplets[:, 0], :], Z[triplets[:, 1], :], Z[triplets[:, 2], :])
    return lam * trip + bce(pred, target.view(-1, 1))


def _save_curve(train_curve, test_curve, ylabel, title):
    """Plot per-epoch train (red) and test (blue) values; save as '<title>.png'."""
    plt.plot(np.squeeze(train_curve), '-r', np.squeeze(test_curve), '-b')
    plt.ylabel(ylabel)
    plt.xlabel('iterations (per tens)')
    plt.suptitle(title)
    plt.savefig(save_results_to + title + '.png', dpi=150)
    plt.close()


for iters in range(max_iter):
    k = 0
    mbs = 30  # minibatch size is fixed, not sampled

    # Sample one configuration from the grids.
    hdm1 = random.choice(ls_h_dim)
    hdm2 = random.choice(ls_h_dim)
    hdm3 = random.choice(ls_h_dim)
    mrg = random.choice(ls_marg)
    lre = random.choice(ls_lr)
    lrm = random.choice(ls_lr)
    lrc = random.choice(ls_lr)
    lrCL = random.choice(ls_lr)
    epch = random.choice(ls_epoch)
    rate1 = random.choice(ls_rate)
    rate2 = random.choice(ls_rate)
    rate3 = random.choice(ls_rate)
    rate4 = random.choice(ls_rate)
    wd = random.choice(ls_wd)
    lam = random.choice(ls_lam)

    for train_index, test_index in skf.split(GDSCE.values, Y):
        k = k + 1  # 1-based fold number, used in the plot titles

        X_trainE = GDSCE.values[train_index, :]
        X_testE = GDSCE.values[test_index, :]
        X_trainM = GDSCM.values[train_index, :]
        X_testM = GDSCM.values[test_index, :]
        X_trainC = GDSCC.values[train_index, :]
        X_testC = GDSCC.values[test_index, :]  # CNA test data (was mistakenly GDSCM)
        y_trainE = Y[train_index]
        y_testE = Y[test_index]

        # Standardise expression using training-fold statistics only.
        scalerGDSC = sk.StandardScaler()
        scalerGDSC.fit(X_trainE)
        X_trainE = scalerGDSC.transform(X_trainE)
        X_testE = scalerGDSC.transform(X_testE)

        X_trainM = np.nan_to_num(X_trainM)
        X_trainC = np.nan_to_num(X_trainC)
        X_testM = np.nan_to_num(X_testM)
        X_testC = np.nan_to_num(X_testC)

        TX_testE = torch.FloatTensor(X_testE)
        TX_testM = torch.FloatTensor(X_testM)
        TX_testC = torch.FloatTensor(X_testC)
        ty_testE = torch.FloatTensor(y_testE.astype(int))

        # Inverse-frequency sample weights so both classes are drawn about equally
        # often. ``weight[t]`` indexes by label value, so labels must be 0..K-1.
        class_sample_count = np.array([len(np.where(y_trainE == t)[0]) for t in np.unique(y_trainE)])
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in y_trainE])
        sampler = WeightedRandomSampler(torch.from_numpy(samples_weight).double(),
                                        len(samples_weight), replacement=True)

        mb_size = mbs
        trainDataset = torch.utils.data.TensorDataset(
            torch.FloatTensor(X_trainE), torch.FloatTensor(X_trainM),
            torch.FloatTensor(X_trainC), torch.FloatTensor(y_trainE.astype(int)))
        trainLoader = torch.utils.data.DataLoader(dataset=trainDataset, batch_size=mb_size,
                                                  shuffle=False, num_workers=1, sampler=sampler)

        IE_dim = X_trainE.shape[1]
        IM_dim = X_trainM.shape[1]
        IC_dim = X_trainC.shape[1]
        Z_in = hdm1 + hdm2 + hdm3

        costtr = []  # per-epoch mean training loss
        auctr = []   # per-epoch mean minibatch training AUC
        costts = []  # per-epoch test loss
        aucts = []   # per-epoch test AUC

        torch.cuda.manual_seed_all(42)

        AutoencoderE = AEE(IE_dim, hdm1, rate1)
        AutoencoderM = AEM(IM_dim, hdm2, rate2)
        AutoencoderC = AEC(IC_dim, hdm3, rate3)  # sized by the CNA input (was IM_dim)

        solverE = optim.Adagrad(AutoencoderE.parameters(), lr=lre)
        solverM = optim.Adagrad(AutoencoderM.parameters(), lr=lrm)
        solverC = optim.Adagrad(AutoencoderC.parameters(), lr=lrc)

        trip_criterion = torch.nn.TripletMarginLoss(margin=mrg, p=2)
        # All triplets of a batch are used, both for training and for evaluation.
        TripSel2 = OnlineTriplet(mrg, AllTripletSelector())

        Clas = Classifier(Z_in, rate4)
        SolverClass = optim.Adagrad(Clas.parameters(), lr=lrCL, weight_decay=wd)
        C_loss = torch.nn.BCELoss()

        models = (AutoencoderE, AutoencoderM, AutoencoderC, Clas)
        solvers = (solverE, solverM, solverC, SolverClass)
        n_batches = len(trainLoader)  # actual number of minibatches per epoch (never 0)

        for it in range(epch):
            for model in models:
                model.train()

            epoch_cost = 0.0
            epoch_aucs = []
            trained = False  # reset once per epoch, not once per minibatch

            for dataE, dataM, dataC, target in trainLoader:
                pos_rate = torch.mean(target)
                # Skip single-class minibatches: roc_auc_score needs both classes
                # (and presumably no valid triplets exist there either).
                if pos_rate == 0. or pos_rate == 1.:
                    continue

                ZT = _embed(AutoencoderE, AutoencoderM, AutoencoderC, dataE, dataM, dataC)
                Pred = Clas(ZT)
                loss = _joint_loss(ZT, Pred, target, TripSel2, trip_criterion, C_loss, lam)

                AUC = roc_auc_score(target.view(-1, 1).detach().numpy(), Pred.detach().numpy())

                for solver in solvers:
                    solver.zero_grad()
                loss.backward()
                for solver in solvers:
                    solver.step()

                # .item() so the epoch total does not keep each batch's graph alive.
                epoch_cost += loss.item() / n_batches
                epoch_aucs.append(AUC)
                trained = True

            if trained:
                costtr.append(epoch_cost)
                auctr.append(np.mean(epoch_aucs))
                print('Iter-{}; Total loss: {:.4}'.format(it, loss.item()))

            # Evaluate the whole held-out fold after every epoch.
            with torch.no_grad():
                for model in models:
                    model.eval()

                ZTT = _embed(AutoencoderE, AutoencoderM, AutoencoderC, TX_testE, TX_testM, TX_testC)
                PredT = Clas(ZTT)
                lossT = _joint_loss(ZTT, PredT, ty_testE, TripSel2, trip_criterion, C_loss, lam)
                AUCt = roc_auc_score(ty_testE.view(-1, 1).numpy(), PredT.numpy())

                costts.append(lossT.item())
                aucts.append(AUCt)

        # Shared part of both plot titles (also used as the file names).
        run_desc = ('iter = {}, fold = {}, mb_size = {},  h_dim[1,2,3] = ({},{},{}), marg = {}, '
                    'lr[E,M,C] = ({}, {}, {}), epoch = {}, rate[1,2,3,4] = ({},{},{},{}), '
                    'wd = {}, lrCL = {}, lam = {}').format(
                        iters, k, mbs, hdm1, hdm2, hdm3, mrg, lre, lrm, lrc, epch,
                        rate1, rate2, rate3, rate4, wd, lrCL, lam)

        _save_curve(costtr, costts, 'Total cost', 'Cost Cetuximab ' + run_desc)
        _save_curve(auctr, aucts, 'AUC', 'AUC Cetuximab ' + run_desc)