Switch to unified view

a b/Cross validation/MOLI only classifier/ErlotinibPDX_cvClassifierNetv9_ScriptCPU.py
1
import torch 
2
import torch.nn as nn
3
import torch.nn.functional as F
4
import torch.optim as optim
5
import numpy as np
6
import matplotlib
7
matplotlib.use('Agg')
8
import matplotlib.pyplot as plt
9
import matplotlib.gridspec as gridspec
10
import pandas as pd
11
import math
12
import sklearn.preprocessing as sk
13
import seaborn as sns
14
from sklearn import metrics
15
from sklearn.feature_selection import VarianceThreshold
16
from sklearn.model_selection import train_test_split
17
from torch.utils.data.sampler import WeightedRandomSampler
18
from sklearn.metrics import roc_auc_score
19
from sklearn.metrics import average_precision_score
20
import random
21
from sklearn.model_selection import StratifiedKFold
22
23
save_results_to = '/home/hnoghabi/CVClassifierResultsv9/Erlotinib/'
24
max_iter = 100
25
torch.manual_seed(42)
26
27
# Load Mutation
28
GDSCE = pd.read_csv("GDSC_exprs.Erlotinib.eb_with.PDX_exprs.Erlotinib.tsv", sep = "\t", index_col=0, decimal = ",")
29
GDSCE = pd.DataFrame.transpose(GDSCE)
30
# Load GDSC response
31
GDSCR = pd.read_csv("GDSC_response.Erlotinib.tsv", sep = "\t", index_col=0, decimal = ",")
32
33
PDXE = pd.read_csv("PDX_exprs.Erlotinib.eb_with.GDSC_exprs.Erlotinib.tsv", sep = "\t", index_col=0, decimal = ",")
34
PDXE = pd.DataFrame.transpose(PDXE)
35
36
PDXM = pd.read_csv("PDX_mutations.Erlotinib.tsv", sep = "\t", index_col=0, decimal = ".")
37
PDXM = pd.DataFrame.transpose(PDXM)
38
39
PDXC = pd.read_csv("PDX_CNA.Erlotinib.tsv", sep = "\t", index_col=0, decimal = ".")
40
PDXC = pd.DataFrame.transpose(PDXC)
41
42
GDSCM = pd.read_csv("GDSC_mutations.Erlotinib.tsv", sep = "\t", index_col=0, decimal = ".")
43
GDSCM = pd.DataFrame.transpose(GDSCM)
44
45
GDSCC = pd.read_csv("GDSC_CNA.Erlotinib.tsv", sep = "\t", index_col=0, decimal = ".")
46
GDSCC.drop_duplicates(keep='last')
47
PDXC = PDXC.loc[:,~PDXC.columns.duplicated()]
48
GDSCC = pd.DataFrame.transpose(GDSCC)
49
selector = VarianceThreshold(0.05)
50
selector.fit_transform(GDSCE)
51
GDSCE = GDSCE[GDSCE.columns[selector.get_support(indices=True)]]
52
53
PDXC = PDXC.fillna(0)
54
PDXC[PDXC != 0.0] = 1
55
PDXM = PDXM.fillna(0)
56
PDXM[PDXM != 0.0] = 1
57
GDSCM = GDSCM.fillna(0)
58
GDSCM[GDSCM != 0.0] = 1
59
GDSCC = GDSCC.fillna(0)
60
GDSCC[GDSCC != 0.0] = 1
61
62
ls = GDSCE.columns.intersection(GDSCM.columns)
63
ls = ls.intersection(GDSCC.columns)
64
ls = ls.intersection(PDXE.columns)
65
ls = ls.intersection(PDXM.columns)
66
ls = ls.intersection(PDXC.columns)
67
ls2 = GDSCE.index.intersection(GDSCM.index)
68
ls2 = ls2.intersection(GDSCC.index)
69
ls3 = PDXE.index.intersection(PDXM.index)
70
ls3 = ls3.intersection(PDXC.index)
71
ls = pd.unique(ls)
72
73
PDXE = PDXE.loc[ls3,ls]
74
PDXM = PDXM.loc[ls3,ls]
75
PDXC = PDXC.loc[ls3,ls]
76
GDSCE = GDSCE.loc[ls2,ls]
77
GDSCM = GDSCM.loc[ls2,ls]
78
GDSCC = GDSCC.loc[ls2,ls]
79
80
GDSCR.loc[GDSCR.iloc[:,0] == 'R'] = 0
81
GDSCR.loc[GDSCR.iloc[:,0] == 'S'] = 1
82
GDSCR.columns = ['targets']
83
GDSCR = GDSCR.loc[ls2,:]
84
85
PDXR = pd.read_csv("PDX_response.Erlotinib.tsv", sep = "\t", index_col=0, decimal = ",")
86
PDXR.loc[PDXR.iloc[:,0] == 'R'] = 0
87
PDXR.loc[PDXR.iloc[:,0] == 'S'] = 1
88
89
Y = GDSCR['targets'].values
90
91
ls_mb_size = [15, 30, 60]
92
ls_h_dim = [1024, 512, 256, 128, 64]
93
ls_z_dim = [128, 64, 32, 16]
94
ls_marg = [0.5, 1, 1.5, 2, 2.5]
95
ls_lr = [0.05, 0.01, 0.001, 0.005, 0.0005, 0.0001,0.00005, 0.00001]
96
ls_epoch = [20, 50, 10, 15, 30, 40, 100]
97
ls_rate = [0.4, 0.5, 0.6, 0.7]
98
ls_wd = [0.01, 0.001, 0.1, 0.0001]
99
100
skf = StratifiedKFold(n_splits=10, random_state=42)
101
    
102
for iters in range(max_iter):
103
    k = 0
104
    mbs = random.choice(ls_mb_size)
105
    hdm = random.choice(ls_h_dim)
106
    zdm = random.choice(ls_z_dim)
107
    lre = random.choice(ls_lr)
108
    lrm = random.choice(ls_lr)
109
    lrc = random.choice(ls_lr)
110
    lrCL = random.choice(ls_lr)
111
    epch = random.choice(ls_epoch)
112
    wd = random.choice(ls_wd)
113
    rate = random.choice(ls_rate)
114
    
115
    for train_index, test_index in skf.split(GDSCE.values, Y):
116
        k = k + 1
117
        X_trainE = GDSCE.values[train_index,:]
118
        X_testE =  GDSCE.values[test_index,:]
119
        X_trainM = GDSCM.values[train_index,:]
120
        X_testM = GDSCM.values[test_index,:]
121
        X_trainC = GDSCC.values[train_index,:]
122
        X_testC = GDSCM.values[test_index,:]
123
        y_trainE = Y[train_index]
124
        y_testE = Y[test_index]
125
        
126
        scalerGDSC = sk.StandardScaler()
127
        scalerGDSC.fit(X_trainE)
128
        X_trainE = scalerGDSC.transform(X_trainE)
129
        X_testE = scalerGDSC.transform(X_testE)
130
131
        X_trainM = np.nan_to_num(X_trainM)
132
        X_trainC = np.nan_to_num(X_trainC)
133
        X_testM = np.nan_to_num(X_testM)
134
        X_testC = np.nan_to_num(X_testC)
135
        
136
        TX_testE = torch.FloatTensor(X_testE)
137
        TX_testM = torch.FloatTensor(X_testM)
138
        TX_testC = torch.FloatTensor(X_testC)
139
        ty_testE = torch.FloatTensor(y_testE.astype(int))
140
        
141
        #Train
142
        class_sample_count = np.array([len(np.where(y_trainE==t)[0]) for t in np.unique(y_trainE)])
143
        weight = 1. / class_sample_count
144
        samples_weight = np.array([weight[t] for t in y_trainE])
145
146
        samples_weight = torch.from_numpy(samples_weight)
147
        sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight), replacement=True)
148
149
        mb_size = mbs
150
151
        trainDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_trainE), torch.FloatTensor(X_trainM), 
152
                                                      torch.FloatTensor(X_trainC), torch.FloatTensor(y_trainE.astype(int)))
153
154
        trainLoader = torch.utils.data.DataLoader(dataset = trainDataset, batch_size=mb_size, shuffle=False, num_workers=1, sampler = sampler)
155
156
        n_sampE, IE_dim = X_trainE.shape
157
        n_sampM, IM_dim = X_trainM.shape
158
        n_sampC, IC_dim = X_trainC.shape
159
160
        h_dim = hdm
161
        Z_dim = zdm
162
        Z_in = h_dim + h_dim + h_dim
163
        lrE = lre
164
        lrM = lrm
165
        lrC = lrc
166
        epoch = epch
167
168
        costtr = []
169
        auctr = []
170
        costts = []
171
        aucts = []
172
173
        class AEE(nn.Module):
174
            def __init__(self):
175
                super(AEE, self).__init__()
176
                self.EnE = torch.nn.Sequential(
177
                    nn.Linear(IE_dim, h_dim),
178
                    nn.BatchNorm1d(h_dim),
179
                    nn.ReLU(),
180
                    nn.Dropout())
181
            def forward(self, x):
182
                output = self.EnE(x)
183
                return output
184
185
        class AEM(nn.Module):
186
            def __init__(self):
187
                super(AEM, self).__init__()
188
                self.EnM = torch.nn.Sequential(
189
                    nn.Linear(IM_dim, h_dim),
190
                    nn.BatchNorm1d(h_dim),
191
                    nn.ReLU(),
192
                    nn.Dropout())
193
            def forward(self, x):
194
                output = self.EnM(x)
195
                return output    
196
197
198
        class AEC(nn.Module):
199
            def __init__(self):
200
                super(AEC, self).__init__()
201
                self.EnC = torch.nn.Sequential(
202
                    nn.Linear(IM_dim, h_dim),
203
                    nn.BatchNorm1d(h_dim),
204
                    nn.ReLU(),
205
                    nn.Dropout())
206
            def forward(self, x):
207
                output = self.EnC(x)
208
                return output      
209
        class Classifier(nn.Module):
210
            def __init__(self):
211
                super(Classifier, self).__init__()
212
                self.FC = torch.nn.Sequential(
213
                    nn.Linear(Z_in, Z_dim),
214
                    nn.ReLU(),
215
                    nn.Dropout(rate),
216
                    nn.Linear(Z_dim, 1),
217
                    nn.Dropout(rate),
218
                    nn.Sigmoid())
219
            def forward(self, x):
220
                return self.FC(x)
221
        
222
        torch.cuda.manual_seed_all(42)
223
224
        AutoencoderE = AEE()
225
        AutoencoderM = AEM()
226
        AutoencoderC = AEC()
227
228
        solverE = optim.Adagrad(AutoencoderE.parameters(), lr=lrE)
229
        solverM = optim.Adagrad(AutoencoderM.parameters(), lr=lrM)
230
        solverC = optim.Adagrad(AutoencoderC.parameters(), lr=lrC)
231
232
        Clas = Classifier()
233
        SolverClass = optim.Adagrad(Clas.parameters(), lr=lrCL, weight_decay = wd)
234
        C_loss = torch.nn.BCELoss()
235
236
        for it in range(epoch):
237
238
            epoch_cost4 = 0
239
            epoch_cost3 = []
240
            num_minibatches = int(n_sampE / mb_size) 
241
242
            for i, (dataE, dataM, dataC, target) in enumerate(trainLoader):
243
                flag = 0
244
                AutoencoderE.train()
245
                AutoencoderM.train()
246
                AutoencoderC.train()
247
                Clas.train()
248
249
                if torch.mean(target)!=0. and torch.mean(target)!=1.:                      
250
251
                    ZEX = AutoencoderE(dataE)
252
                    ZMX = AutoencoderM(dataM)
253
                    ZCX = AutoencoderC(dataC)
254
255
                    ZT = torch.cat((ZEX, ZMX, ZCX), 1)
256
                    ZT = F.normalize(ZT, p=2, dim=0)
257
258
                    Pred = Clas(ZT)
259
                    loss = C_loss(Pred,target.view(-1,1))   
260
261
                    y_true = target.view(-1,1)
262
                    y_pred = Pred
263
                    AUC = roc_auc_score(y_true.detach().numpy(),y_pred.detach().numpy()) 
264
265
                    solverE.zero_grad()
266
                    solverM.zero_grad()
267
                    solverC.zero_grad()
268
                    SolverClass.zero_grad()
269
270
                    loss.backward()
271
272
                    solverE.step()
273
                    solverM.step()
274
                    solverC.step()
275
                    SolverClass.step()
276
                    
277
                    epoch_cost4 = epoch_cost4 + (loss / num_minibatches)
278
                    epoch_cost3.append(AUC)
279
                    flag = 1
280
281
            if flag == 1:
282
                costtr.append(torch.mean(epoch_cost4))
283
                auctr.append(np.mean(epoch_cost3))
284
                print('Iter-{}; Total loss: {:.4}'.format(it, loss))
285
286
            with torch.no_grad():
287
288
                AutoencoderE.eval()
289
                AutoencoderM.eval()
290
                AutoencoderC.eval()
291
                Clas.eval()
292
293
                ZET = AutoencoderE(TX_testE)
294
                ZMT = AutoencoderM(TX_testM)
295
                ZCT = AutoencoderC(TX_testC)
296
297
                ZTT = torch.cat((ZET, ZMT, ZCT), 1)
298
                ZTT = F.normalize(ZTT, p=2, dim=0)
299
300
                PredT = Clas(ZTT)
301
                lossT = C_loss(PredT,ty_testE.view(-1,1))         
302
303
                y_truet = ty_testE.view(-1,1)
304
                y_predt = PredT
305
                AUCt = roc_auc_score(y_truet.detach().numpy(),y_predt.detach().numpy())
306
307
                costts.append(lossT)
308
                aucts.append(AUCt)
309
310
        plt.plot(np.squeeze(costtr), '-r',np.squeeze(costts), '-b')
311
        plt.ylabel('Total cost')
312
        plt.xlabel('iterations (per tens)')
313
314
        title = 'Cost Erlotinib iter = {}, fold = {}, mb_size = {},  hz_dim[1,2] = ({},{}), lr[E,M,C] = ({}, {}, {}), epoch = {}, wd = {}, lrCL = {}, rate4 = {}'.\
315
                      format(iters, k, mbs, hdm, zdm , lre, lrm, lrc, epch, wd, lrCL, rate)
316
317
        plt.suptitle(title)
318
        plt.savefig(save_results_to + title + '.png', dpi = 150)
319
        plt.close()
320
321
        plt.plot(np.squeeze(auctr), '-r',np.squeeze(aucts), '-b')
322
        plt.ylabel('AUC')
323
        plt.xlabel('iterations (per tens)')
324
325
        title = 'AUC Erlotinib iter = {}, fold = {}, mb_size = {},  hz_dim[1,2] = ({},{}), lr[E,M,C] = ({}, {}, {}), epoch = {}, wd = {}, lrCL = {}, rate4 = {}'.\
326
                      format(iters, k, mbs, hdm, zdm , lre, lrm, lrc, epch, wd, lrCL, rate)        
327
328
        plt.suptitle(title)
329
        plt.savefig(save_results_to + title + '.png', dpi = 150)
330
        plt.close()