a b/Cross validation/MOLI only expression/Paclitaxel_OnlyExprsv2_Script.py
1
import torch 
2
import torch.nn as nn
3
import torch.nn.functional as F
4
import torch.optim as optim
5
import numpy as np
6
import matplotlib
7
matplotlib.use('Agg')
8
import matplotlib.pyplot as plt
9
import matplotlib.gridspec as gridspec
10
import pandas as pd
11
import math
12
import sklearn.preprocessing as sk
13
import seaborn as sns
14
from sklearn import metrics
15
from sklearn.feature_selection import VarianceThreshold
16
from sklearn.model_selection import train_test_split
17
from utils import AllTripletSelector,HardestNegativeTripletSelector, RandomNegativeTripletSelector, SemihardNegativeTripletSelector # Strategies for selecting triplets within a minibatch
18
from metrics import AverageNonzeroTripletsMetric
19
from torch.utils.data.sampler import WeightedRandomSampler
20
from sklearn.metrics import roc_auc_score
21
from sklearn.metrics import average_precision_score
22
import random
23
from random import randint
24
from sklearn.model_selection import StratifiedKFold
25
26
# Directory for the per-fold cost/AUC figures. Figure file names are appended
# to it by plain string concatenation, so it must end with a path separator.
save_results_to = '/home/hnoghabi/OnlyExprsv2/Paclitaxel/'

# Fix both random sources used by the script: Python's `random` (hyper-parameter
# draws) and torch (weight initialisation and minibatch sampling).
random.seed(42)
torch.manual_seed(42)

# Number of random hyper-parameter configurations evaluated by the search.
max_iter = 50
def _read_tsv(path, transpose=True):
    """Read a tab-separated table indexed by its first column.

    Numbers in the file use ',' as the decimal separator. When *transpose* is
    true the table is transposed, so the file's header labels become the row
    index and its first column becomes the column labels.
    """
    table = pd.read_csv(path, sep="\t", index_col=0, decimal=",")
    return pd.DataFrame.transpose(table) if transpose else table


def _drop_duplicate_columns(frame):
    """Return *frame* keeping only the first column for each repeated label."""
    return frame.loc[:, ~frame.columns.duplicated()]


GDSCE = _read_tsv("GDSC_exprs.Paclitaxel.eb_with.PDX_exprs.Paclitaxel.tsv")
GDSCR = _read_tsv("GDSC_response.Paclitaxel.tsv", transpose=False)
PDXE = _read_tsv("PDX_exprs.Paclitaxel.eb_with.GDSC_exprs.Paclitaxel.tsv")
PDXM = _read_tsv("PDX_mutations.Paclitaxel.tsv")
# The CNA tables may repeat a label. Keep one column per label (first
# occurrence) so later label-based selection (.loc[..., ls]) yields exactly one
# column each. Deduplication is done on the transposed frame and its result is
# assigned; calling drop_duplicates() without assignment would be a no-op.
PDXC = _drop_duplicate_columns(_read_tsv("PDX_CNA.Paclitaxel.tsv"))
GDSCM = _read_tsv("GDSC_mutations.Paclitaxel.tsv")
GDSCC = _drop_duplicate_columns(_read_tsv("GDSC_CNA.Paclitaxel.tsv"))
def _encode_response(resp):
    """Encode response labels in place and return *resp*.

    Every row whose first column equals 'R' is set to 0 and every row whose
    first column equals 'S' is set to 1. The whole row is overwritten, so this
    is meant for single-column response tables (NOTE(review): presumably true
    for both response files — confirm). Any other label is left untouched.
    """
    resp.loc[resp.iloc[:, 0] == 'R'] = 0
    resp.loc[resp.iloc[:, 0] == 'S'] = 1
    return resp


# Drop low-variance GDSC expression features. Only the fitted support mask is
# needed, so fit() suffices (no transformed copy is built). Selecting by
# position keeps the filter exact even if column labels repeat; selecting by
# name would re-admit every column sharing a kept label.
selector = VarianceThreshold(0.05)
selector.fit(GDSCE)
GDSCE = GDSCE.iloc[:, selector.get_support(indices=True)]

# Features present in every GDSC and PDX table, with duplicates removed.
ls = (GDSCE.columns
      .intersection(GDSCM.columns)
      .intersection(GDSCC.columns)
      .intersection(PDXE.columns)
      .intersection(PDXM.columns)
      .intersection(PDXC.columns))
ls = pd.unique(ls)
# Samples present in all three GDSC tables, and in all three PDX tables.
ls2 = GDSCE.index.intersection(GDSCM.index).intersection(GDSCC.index)
ls3 = PDXE.index.intersection(PDXM.index).intersection(PDXC.index)

# Restrict every table to the shared samples and shared features.
PDXE = PDXE.loc[ls3, ls]
PDXM = PDXM.loc[ls3, ls]
PDXC = PDXC.loc[ls3, ls]
GDSCE = GDSCE.loc[ls2, ls]
GDSCM = GDSCM.loc[ls2, ls]
GDSCC = GDSCC.loc[ls2, ls]

# GDSC targets: 'R' -> 0, 'S' -> 1. The column is cast to int (after restricting
# to ls2) so downstream code sees numeric labels rather than an object column.
GDSCR = _encode_response(GDSCR)
GDSCR.columns = ['targets']
GDSCR = GDSCR.loc[ls2, :].astype({'targets': int})

PDXR = _encode_response(pd.read_csv("PDX_response.Paclitaxel.tsv",
                                    sep="\t", index_col=0, decimal=","))
# Random-search space. Each search iteration draws one value from every list;
# ls_lr is drawn twice (encoder and classifier learning rates).
ls_mb_size = [13, 36, 64]
ls_h_dim = [512, 256, 128, 64]
ls_marg = [0.5, 1, 1.5, 2]
ls_lr = [0.5, 0.1, 0.05, 0.01, 0.001, 0.005, 0.0005, 0.0001, 0.00005, 0.00001]
ls_epoch = [20, 50, 10, 15, 30, 40, 60, 70, 80, 90, 100]
ls_rate = [0.3, 0.4, 0.5]
ls_wd = [0.01, 0.001]
ls_lam = [0.1, 0.5, 0.01, 0.05, 0.001, 0.005]

# Binary targets (0 = 'R', 1 = 'S') as a plain int array. An object-dtype
# array of ints may not be recognised as a binary target by scikit-learn.
Y = GDSCR['targets'].values.astype(int)

# Unshuffled, deterministic 5-fold stratified split. random_state is omitted:
# it has no effect without shuffle=True, and scikit-learn >= 0.24 raises
# ValueError when it is set while shuffle is False.
skf = StratifiedKFold(n_splits=5)
# Random hyper-parameter search with 5-fold CV on GDSC expression only. For each
# drawn configuration and fold, an encoder + sigmoid classifier is trained with
# lam * triplet-margin loss + BCE, and per-epoch train/test cost and AUC curves
# are saved as PNGs under save_results_to.


class AEE(nn.Module):
    """Expression encoder: Linear(in_dim, h_dim) -> BatchNorm1d -> ReLU -> Dropout()."""

    def __init__(self, in_dim, h_dim):
        super(AEE, self).__init__()
        self.EnE = torch.nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.BatchNorm1d(h_dim),
            nn.ReLU(),
            nn.Dropout())

    def forward(self, x):
        return self.EnE(x)


class Classifier(nn.Module):
    """Response head: Linear(z_in, 1) -> Dropout(rate) -> Sigmoid."""

    def __init__(self, z_in, rate):
        super(Classifier, self).__init__()
        self.FC = torch.nn.Sequential(
            nn.Linear(z_in, 1),
            nn.Dropout(rate),
            nn.Sigmoid())

    def forward(self, x):
        return self.FC(x)


def _make_balanced_loader(X, y, batch_size):
    """Return a DataLoader over (X, y) that oversamples rarer labels.

    Each epoch draws len(y) rows with replacement. A row's sampling weight is
    1 / (number of rows sharing its label), so every label carries the same
    total weight.
    """
    _, class_idx, class_counts = np.unique(y, return_inverse=True, return_counts=True)
    samples_weight = torch.from_numpy(1.0 / class_counts[class_idx]).double()
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)
    dataset = torch.utils.data.TensorDataset(torch.FloatTensor(X), torch.FloatTensor(y))
    # shuffle must stay False when an explicit sampler is supplied.
    return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size,
                                       shuffle=False, num_workers=1, sampler=sampler)


def _joint_loss(emb, pred, target, selector, trip_criterion, bce, lam):
    """Return lam * triplet loss on *emb* + BCE(pred, target).

    Triplet row indices come from selector.get_triplets(emb, target); their
    three columns are passed to *trip_criterion* as anchor, positive, negative.
    """
    trip = selector.get_triplets(emb, target)
    trip_loss = trip_criterion(emb[trip[:, 0], :], emb[trip[:, 1], :], emb[trip[:, 2], :])
    return lam * trip_loss + bce(pred, target.view(-1, 1))


def _save_curves(train_curve, test_curve, ylabel, title):
    """Plot per-epoch train (red) and test (blue) curves and save <title>.png."""
    # NaN entries (epochs with no trainable minibatch) are drawn as gaps.
    plt.plot(np.asarray(train_curve, dtype=float), '-r',
             np.asarray(test_curve, dtype=float), '-b')
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.suptitle(title)
    plt.savefig(save_results_to + title + '.png', dpi=150)
    plt.close()


X_all = GDSCE.values

for iters in range(max_iter):
    # One configuration per search iteration, shared by all folds.
    mbs = random.choice(ls_mb_size)
    hdm = random.choice(ls_h_dim)
    mrg = random.choice(ls_marg)
    lre = random.choice(ls_lr)
    lrCL = random.choice(ls_lr)
    epch = random.choice(ls_epoch)
    rate = random.choice(ls_rate)
    wd = random.choice(ls_wd)
    lam = random.choice(ls_lam)

    for k, (train_index, test_index) in enumerate(skf.split(X_all, Y), start=1):
        y_train = Y[train_index].astype(int)
        y_test = Y[test_index].astype(int)

        # Standardise with statistics from the training fold only.
        scalerGDSC = sk.StandardScaler()
        X_trainE = scalerGDSC.fit_transform(X_all[train_index, :])
        X_testE = scalerGDSC.transform(X_all[test_index, :])

        TX_testE = torch.FloatTensor(X_testE)
        ty_testE = torch.FloatTensor(y_test)

        trainLoader = _make_balanced_loader(X_trainE, y_train, mbs)

        torch.cuda.manual_seed_all(42)
        AutoencoderE = AEE(X_trainE.shape[1], hdm)
        solverE = optim.Adagrad(AutoencoderE.parameters(), lr=lre)
        Clas = Classifier(hdm, rate)
        SolverClass = optim.Adagrad(Clas.parameters(), lr=lrCL, weight_decay=wd)

        trip_criterion = torch.nn.TripletMarginLoss(margin=mrg, p=2)
        C_loss = torch.nn.BCELoss()
        triplet_selector = AllTripletSelector()

        # Per-epoch curves; all four lists get exactly one entry per epoch.
        costtr, auctr, costts, aucts = [], [], [], []

        for it in range(epch):
            AutoencoderE.train()
            Clas.train()
            batch_losses = []
            batch_aucs = []

            for dataE, target in trainLoader:
                # Skip minibatches containing a single label: AUC needs both
                # labels. This also skips size-1 batches, which BatchNorm1d
                # cannot normalise in training mode.
                n_pos = int(target.sum().item())
                if n_pos == 0 or n_pos == len(target):
                    continue

                ZEX = AutoencoderE(dataE)
                Pred = Clas(ZEX)
                loss = _joint_loss(ZEX, Pred, target, triplet_selector,
                                   trip_criterion, C_loss, lam)
                batch_aucs.append(roc_auc_score(target.numpy(),
                                                Pred.detach().numpy().ravel()))

                solverE.zero_grad()
                SolverClass.zero_grad()
                loss.backward()
                solverE.step()
                SolverClass.step()

                # .item() stores a plain float so no autograd graph is retained.
                batch_losses.append(loss.item())

            if batch_losses:
                # Mean over the minibatches actually trained this epoch.
                costtr.append(float(np.mean(batch_losses)))
                auctr.append(float(np.mean(batch_aucs)))
                print('Iter-{}; Total loss: {:.4}'.format(it, costtr[-1]))
            else:
                # Keep train curves aligned with test curves (one point/epoch).
                costtr.append(float('nan'))
                auctr.append(float('nan'))

            with torch.no_grad():
                AutoencoderE.eval()
                Clas.eval()
                ZET = AutoencoderE(TX_testE)
                PredT = Clas(ZET)
                lossT = _joint_loss(ZET, PredT, ty_testE, triplet_selector,
                                    trip_criterion, C_loss, lam)
                AUCt = roc_auc_score(y_test, PredT.numpy().ravel())
            costts.append(lossT.item())
            aucts.append(AUCt)

        params = ('iter = {}, fold = {}, mb_size = {},  h_dim = {}, marg = {}, lrE = {}, '
                  'epoch = {}, rate = {}, wd = {}, lrCL = {}, lam = {}').format(
                      iters, k, mbs, hdm, mrg, lre, epch, rate, wd, lrCL, lam)
        _save_curves(costtr, costts, 'Total cost', 'Cost Paclitaxel ' + params)
        _save_curves(auctr, aucts, 'AUC', 'AUC Paclitaxel ' + params)