# Cross validation/MOLI only expression/CisplatinTCGA_OnlyExprsv2_Script.py
1
import torch 
2
import torch.nn as nn
3
import torch.nn.functional as F
4
import torch.optim as optim
5
import numpy as np
6
import matplotlib
7
matplotlib.use('Agg')
8
import matplotlib.pyplot as plt
9
import matplotlib.gridspec as gridspec
10
import pandas as pd
11
import math
12
import sklearn.preprocessing as sk
13
import seaborn as sns
14
from sklearn import metrics
15
from sklearn.feature_selection import VarianceThreshold
16
from sklearn.model_selection import train_test_split
17
from utils import AllTripletSelector,HardestNegativeTripletSelector, RandomNegativeTripletSelector, SemihardNegativeTripletSelector # Strategies for selecting triplets within a minibatch
18
from metrics import AverageNonzeroTripletsMetric
19
from torch.utils.data.sampler import WeightedRandomSampler
20
from sklearn.metrics import roc_auc_score
21
from sklearn.metrics import average_precision_score
22
import random
23
from random import randint
24
from sklearn.model_selection import StratifiedKFold
25
26
# Directory prefix under which the per-fold learning-curve figures are saved.
save_results_to = '/home/hnoghabi/OnlyExprsv2/Cisplatin/'

# Seed the two RNGs the script relies on: the stdlib `random` module drives
# the hyper-parameter draws, torch drives the model / sampler randomness.
random.seed(42)
torch.manual_seed(42)

# Number of random hyper-parameter configurations to evaluate.
max_iter = 50
31
32
def _load_matrix(path, decimal, dedup_columns=False):
    """Read a tab-separated matrix and return it transposed.

    Parameters
    ----------
    path : str
        Location of the TSV file; its first column becomes the index.
    decimal : str
        Character the file uses as decimal mark ("," or ".").
    dedup_columns : bool
        When True, repeated column labels of the transposed frame are
        collapsed by keeping only the first occurrence of each label.

    Returns
    -------
    pandas.DataFrame
        The transposed (and optionally de-duplicated) frame.
    """
    frame = pd.read_csv(path, sep="\t", index_col=0, decimal=decimal).T
    if dedup_columns:
        frame = frame.loc[:, ~frame.columns.duplicated()]
    return frame


# Expression matrices (decimal mark is ","), transposed after loading.
GDSCE = _load_matrix("GDSC_exprs.Cisplatin.eb_with.TCGA_exprs.Cisplatin.tsv",
                     decimal=",")
TCGAE = _load_matrix("TCGA_exprs.Cisplatin.eb_with.GDSC_exprs.Cisplatin.tsv",
                     decimal=",")

# Mutation and copy-number matrices (decimal mark is "."); repeated column
# labels are reduced to their first occurrence.
TCGAM = _load_matrix("TCGA_mutations.Cisplatin.tsv", decimal=".",
                     dedup_columns=True)
TCGAC = _load_matrix("TCGA_CNA.Cisplatin.tsv", decimal=".",
                     dedup_columns=True)
GDSCM = _load_matrix("GDSC_mutations.Cisplatin.tsv", decimal=".",
                     dedup_columns=True)
# NOTE: the script used to call ``GDSCC.drop_duplicates(keep='last')`` on the
# raw GDSC CNA frame and throw the result away.  ``drop_duplicates`` is not
# in-place, so the call never changed the data; it is omitted here, which
# leaves the loaded matrix exactly as before.
GDSCC = _load_matrix("GDSC_CNA.Cisplatin.tsv", decimal=".",
                     dedup_columns=True)
60
61
def _binarise(frame):
    """Return a copy of `frame` with missing values set to 0 and every
    non-zero entry set to 1."""
    frame = frame.fillna(0)
    frame[frame != 0.0] = 1
    return frame


def _shared_labels(first, *others):
    """Return the labels of `first` that occur in every one of `others`.

    The result is a list without repeats that follows the order of `first`.
    A list (rather than a set) is returned for two reasons: recent pandas
    versions reject sets as ``.loc`` indexers, and set iteration order for
    strings changes between interpreter runs, which would make row/column
    order -- and therefore the CV folds -- non-reproducible.
    """
    common = set(first).intersection(*(set(labels) for labels in others))
    return [label for label in dict.fromkeys(first) if label in common]


# Keep only expression features whose variance over the GDSC samples is above
# 0.05.  Only the fitted support mask is needed, so `fit` is enough.
selector = VarianceThreshold(0.05)
selector.fit(GDSCE)
GDSCE = GDSCE[GDSCE.columns[selector.get_support(indices=True)]]

# Mutation / copy-number matrices become 0/1 indicator matrices.
TCGAC = _binarise(TCGAC)
TCGAM = _binarise(TCGAM)
GDSCM = _binarise(GDSCM)
GDSCC = _binarise(GDSCC)

# ls  : column labels present in all six matrices
# ls2 : row labels present in all three GDSC matrices
# ls3 : row labels present in all three TCGA matrices
ls = _shared_labels(GDSCE.columns, GDSCM.columns, GDSCC.columns,
                    TCGAE.columns, TCGAM.columns, TCGAC.columns)
ls2 = _shared_labels(GDSCE.index, GDSCM.index, GDSCC.index)
ls3 = _shared_labels(TCGAE.index, TCGAM.index, TCGAC.index)

# Restrict every matrix to the shared rows and columns, in one common order.
TCGAE = TCGAE.loc[ls3, ls]
TCGAM = TCGAM.loc[ls3, ls]
TCGAC = TCGAC.loc[ls3, ls]
GDSCE = GDSCE.loc[ls2, ls]
GDSCM = GDSCM.loc[ls2, ls]
GDSCC = GDSCC.loc[ls2, ls]
91
92
# Response tables; the first column of each file becomes the index.
GDSCR = pd.read_csv("GDSC_response.Cisplatin.tsv",
                    sep="\t", index_col=0, decimal=",")
TCGAR = pd.read_csv("TCGA_response.Cisplatin.tsv",
                    sep="\t", index_col=0, decimal=",")

# Cast the GDSC row labels to str before they are looked up below.
GDSCR.rename(mapper=str, axis='index', inplace=True)

# Align the response rows with the rows of the expression matrices, so that
# Y[i] is the label of GDSCE row i.  Indexing with the frames' own indexes
# (instead of the `ls2` / `ls3` collections) keeps this correct regardless of
# how those collections are ordered, and avoids passing a set to `.loc`,
# which recent pandas versions reject.
GDSCR = GDSCR.loc[GDSCE.index, :]
TCGAR = TCGAR.loc[TCGAE.index, :]

# Encode the response labels: "R" -> 0, "S" -> 1
# (presumably resistant / sensitive -- confirm against the data description).
# Any other label raises KeyError instead of silently becoming NaN.
d = {"R": 0, "S": 1}
GDSCR["response"] = GDSCR.loc[:, "response"].apply(lambda x: d[x])
TCGAR["response"] = TCGAR.loc[:, "response"].apply(lambda x: d[x])

# Binary training labels, in GDSCE row order.
Y = GDSCR['response'].values
111
#y_test = TCGAR['response'].values
112
113
# Candidate values for the random hyper-parameter search.  One entry of each
# list is drawn with `random.choice` per search iteration, so the ORDER of
# the entries matters for reproducing a seeded run.

# Minibatch size.
ls_mb_size = [
    13, 36, 64,
]
# Width of the encoder output (and classifier input).
ls_h_dim = [
    1024, 256, 128, 512, 64, 16,
]
# Triplet-loss margin.
ls_marg = [
    0.5, 1, 1.5, 2, 2.5, 3,
]
# Learning rates (drawn independently for encoder and classifier).
ls_lr = [
    0.5, 0.1, 0.05, 0.01, 0.001,
    0.005, 0.0005, 0.0001, 0.00005, 0.00001,
]
# Number of training epochs.
ls_epoch = [
    20, 50, 90, 100,
]
# Dropout probability of the classifier.
ls_rate = [
    0.3, 0.4, 0.5,
]
# Weight decay of the classifier optimiser.
ls_wd = [
    0.1, 0.001, 0.0001,
]
# Weight of the triplet term in the total loss.
ls_lam = [
    0.1, 0.5, 0.01, 0.05, 0.001, 0.005,
]

# Alternative grids, currently disabled:
#ls_h_dim = [32, 16, 8, 4]
#ls_mb_size = [36, 70]
#ls_h_dim = [1024, 256, 128, 64]
#ls_marg = [0.5, 1, 2, 2.5]
#ls_lr = [0.5, 0.1, 0.05, 0.01, 0.001, 0.005, 0.0005, 0.0001,0.00005, 0.00001]
#ls_lr1 = [0.1, 0.05, 0.01, 0.005]
#ls_lr2 = [0.5, 0.001, 0.005, 0.00005]
#ls_lr3 = [0.5, 0.01, 0.001, 0.005, 0.0005, 0.0001,]
#ls_lr4 = [0.01, 0.001, 0.00001]
#ls_epoch = [10, 50, 20]
#ls_rate = [0.3, 0.4, 0.5]
#ls_wd = [0.001, 0.0001]
135
136
# Five stratified folds taken in dataset order (no shuffling).
# `random_state` is deliberately not passed: it only has an effect together
# with shuffle=True, and scikit-learn >= 0.24 raises ValueError when it is
# set while shuffle is left at its default of False.
skf = StratifiedKFold(n_splits=5)
137
    
138
def _save_curve(train_values, test_values, ylabel, title):
    """Plot a training series (red) against a test series (blue) and save it.

    The figure is written to ``save_results_to + title + '.png'`` and closed
    afterwards so figures do not pile up across folds.
    """
    plt.plot(np.squeeze(train_values), '-r', np.squeeze(test_values), '-b')
    plt.ylabel(ylabel)
    plt.xlabel('iterations (per tens)')
    plt.suptitle(title)
    plt.savefig(save_results_to + title + '.png', dpi = 150)
    plt.close()


# Random search: every outer iteration draws one hyper-parameter configuration
# and evaluates it on each fold produced by `skf`.
for iters in range(max_iter):
    # The order of these draws fixes which values a seeded run obtains;
    # do not reorder them.
    mbs = random.choice(ls_mb_size)
    hdm = random.choice(ls_h_dim)
    mrg = random.choice(ls_marg)
    lre = random.choice(ls_lr)
    lrCL = random.choice(ls_lr)
    epch = random.choice(ls_epoch)
    rate = random.choice(ls_rate)
    wd = random.choice(ls_wd)
    lam = random.choice(ls_lam)

    for k, (train_index, test_index) in enumerate(skf.split(GDSCE.values, Y), start=1):
        X_trainE = GDSCE.values[train_index,:]
        X_testE = GDSCE.values[test_index,:]
        y_trainE = Y[train_index]
        y_testE = Y[test_index]

        # Standardise with statistics of the training fold only.
        scalerGDSC = sk.StandardScaler()
        scalerGDSC.fit(X_trainE)
        X_trainE = scalerGDSC.transform(X_trainE)
        X_testE = scalerGDSC.transform(X_testE)

        TX_testE = torch.FloatTensor(X_testE)
        ty_testE = torch.FloatTensor(y_testE.astype(int))

        # Class-balanced sampling: each sample is weighted by the inverse
        # frequency of its class.  `weight` is indexed by the label value,
        # which relies on the labels being 0..n_classes-1 with every class
        # present in the training fold.
        _, class_sample_count = np.unique(y_trainE, return_counts=True)
        weight = 1. / class_sample_count
        samples_weight = torch.from_numpy(weight[y_trainE.astype(int)])
        sampler = WeightedRandomSampler(samples_weight.double(), len(samples_weight), replacement=True)

        mb_size = mbs

        trainDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_trainE), torch.FloatTensor(y_trainE.astype(int)))
        trainLoader = torch.utils.data.DataLoader(dataset = trainDataset, batch_size=mb_size, shuffle=False, num_workers=1, sampler = sampler)

        n_sampE, IE_dim = X_trainE.shape

        h_dim = hdm
        Z_in = h_dim
        marg = mrg
        lrE = lre
        epoch = epch

        # Per-epoch learning curves for this fold (plain Python floats).
        costtr = []
        auctr = []
        costts = []
        aucts = []

        # NOTE(review): `triplet_selector` / `TripSel` are built but never
        # used; training and evaluation both go through `TripSel2`.
        triplet_selector = RandomNegativeTripletSelector(marg)
        triplet_selector2 = AllTripletSelector()

        # The model classes are (re)defined per fold because they read the
        # fold-specific globals IE_dim, h_dim, Z_in and rate.
        class AEE(nn.Module):
            """Encoder: Linear(IE_dim -> h_dim) + BatchNorm + ReLU + Dropout.

            The dropout layer uses torch's default probability; the sampled
            `rate` is only applied in `Classifier`.
            """
            def __init__(self):
                super(AEE, self).__init__()
                self.EnE = torch.nn.Sequential(
                    nn.Linear(IE_dim, h_dim),
                    nn.BatchNorm1d(h_dim),
                    nn.ReLU(),
                    nn.Dropout())
            def forward(self, x):
                output = self.EnE(x)
                return output

        class OnlineTriplet(nn.Module):
            """Thin module wrapper returning the triplets chosen by
            `triplet_selector` for a batch of embeddings and targets."""
            def __init__(self, marg, triplet_selector):
                super(OnlineTriplet, self).__init__()
                self.marg = marg
                self.triplet_selector = triplet_selector
            def forward(self, embeddings, target):
                triplets = self.triplet_selector.get_triplets(embeddings, target)
                return triplets

        class OnlineTestTriplet(nn.Module):
            """Same wrapper as `OnlineTriplet`, kept as a separate class."""
            def __init__(self, marg, triplet_selector):
                super(OnlineTestTriplet, self).__init__()
                self.marg = marg
                self.triplet_selector = triplet_selector
            def forward(self, embeddings, target):
                triplets = self.triplet_selector.get_triplets(embeddings, target)
                return triplets

        class Classifier(nn.Module):
            """Classifier head: Linear(Z_in -> 1) + Dropout(rate) + Sigmoid."""
            def __init__(self):
                super(Classifier, self).__init__()
                self.FC = torch.nn.Sequential(
                    nn.Linear(Z_in, 1),
                    nn.Dropout(rate),
                    nn.Sigmoid())
            def forward(self, x):
                return self.FC(x)

        torch.cuda.manual_seed_all(42)

        AutoencoderE = AEE()

        solverE = optim.Adagrad(AutoencoderE.parameters(), lr=lrE)

        trip_criterion = torch.nn.TripletMarginLoss(margin=marg, p=2)
        TripSel = OnlineTriplet(marg, triplet_selector)
        TripSel2 = OnlineTestTriplet(marg, triplet_selector2)

        Clas = Classifier()
        SolverClass = optim.Adagrad(Clas.parameters(), lr=lrCL, weight_decay = wd)
        C_loss = torch.nn.BCELoss()

        # Normaliser for the epoch loss; clamped to 1 so a training fold
        # smaller than one minibatch cannot cause a division by zero.
        num_minibatches = max(1, n_sampE // mb_size)

        for it in range(epoch):
            epoch_loss = 0.0      # running, normalised training loss (float)
            epoch_aucs = []       # training AUC of every minibatch used
            last_loss = None      # loss of the most recent minibatch used

            AutoencoderE.train()
            Clas.train()

            for dataE, target in trainLoader:
                # Skip single-class minibatches: neither triplets nor an AUC
                # can be formed from them.
                batch_mean = torch.mean(target)
                if batch_mean == 0. or batch_mean == 1.:
                    continue

                ZEX = AutoencoderE(dataE)
                Pred = Clas(ZEX)

                Triplets = TripSel2(ZEX, target)
                loss = lam * trip_criterion(ZEX[Triplets[:,0],:],ZEX[Triplets[:,1],:],ZEX[Triplets[:,2],:]) + C_loss(Pred,target.view(-1,1))

                AUC = roc_auc_score(target.view(-1,1).detach().numpy(), Pred.detach().numpy())

                solverE.zero_grad()
                SolverClass.zero_grad()

                loss.backward()

                solverE.step()
                SolverClass.step()

                # `.item()` detaches the value; accumulating the tensor
                # itself would keep every minibatch graph alive for the
                # whole epoch and store grad-requiring tensors in `costtr`.
                last_loss = loss.item()
                epoch_loss += last_loss / num_minibatches
                epoch_aucs.append(AUC)

            # Record the epoch if ANY minibatch was trained on.  (Previously
            # a flag reset on every minibatch meant the epoch was dropped
            # whenever the last minibatch happened to be single-class.)
            if epoch_aucs:
                costtr.append(epoch_loss)
                auctr.append(np.mean(epoch_aucs))
                print('Iter-{}; Total loss: {:.4}'.format(it, last_loss))

            # Evaluate on the held-out fold after every epoch.
            with torch.no_grad():
                AutoencoderE.eval()
                Clas.eval()

                ZET = AutoencoderE(TX_testE)
                PredT = Clas(ZET)

                TripletsT = TripSel2(ZET, ty_testE)
                lossT = lam * trip_criterion(ZET[TripletsT[:,0],:], ZET[TripletsT[:,1],:], ZET[TripletsT[:,2],:]) + C_loss(PredT,ty_testE.view(-1,1))

                AUCt = roc_auc_score(ty_testE.view(-1,1).detach().numpy(), PredT.detach().numpy())

            costts.append(lossT.item())
            aucts.append(AUCt)

        # One cost figure and one AUC figure per (configuration, fold); the
        # full configuration is encoded in the title / file name.
        run_tag = 'Cisplatin iter = {}, fold = {}, mb_size = {},  h_dim = {}, marg = {}, lrE = {}, epoch = {}, rate = {}, wd = {}, lrCL = {}, lam = {}'.\
                      format(iters, k, mbs, hdm, mrg, lre, epch, rate, wd, lrCL, lam)

        _save_curve(costtr, costts, 'Total cost', 'Cost ' + run_tag)
        _save_curve(auctr, aucts, 'AUC', 'AUC ' + run_tag)