Diff of /model/SALMON.py [000000] .. [a23a6e]

Switch to unified view

a b/model/SALMON.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
"""
4
@author: Zhi Huang
5
"""
6
import argparse, random
7
import torch
8
import torch.nn as nn
9
import torch.backends.cudnn as cudnn
10
import torch.nn.functional as F
11
import torch.optim as optim
12
from torch.utils.data import DataLoader
13
from torchvision import datasets, transforms
14
from torch.autograd import Variable
15
from collections import Counter
16
import pandas as pd
17
import matplotlib.pyplot as plt
18
import math
19
import random
20
from imblearn.over_sampling import RandomOverSampler
21
import pandas as pd
22
from lifelines.statistics import logrank_test
23
from lifelines.utils import concordance_index
24
import tables
25
import csv
26
import numpy as np
27
import json
28
from tqdm import tqdm
29
import gc
30
import copy
31
32
33
class SALMON(nn.Module):
34
    def __init__(self, input_dim, dropout_rate, length_of_data, label_dim):
35
        super(SALMON, self).__init__()
36
        
37
        self.length_of_data = length_of_data
38
        hidden1 = 8
39
        hidden2 = 4
40
        
41
        if input_dim == length_of_data['mRNAseq']: # mRNAseq
42
            self.encoder1 = nn.Sequential(nn.Linear(input_dim, hidden1),nn.Sigmoid())
43
            self.classifier = nn.Sequential(nn.Linear(hidden1, label_dim),nn.Sigmoid())
44
            
45
        if input_dim == length_of_data['miRNAseq']: # miRNAseq
46
            self.encoder2 = nn.Sequential(nn.Linear(input_dim, hidden2),nn.Sigmoid())
47
            self.classifier = nn.Sequential(nn.Linear(hidden2, label_dim),nn.Sigmoid())
48
            
49
        if input_dim == length_of_data['mRNAseq'] + length_of_data['miRNAseq']: # mRNAseq + miRNAseq
50
            self.encoder1 = nn.Sequential(nn.Linear(length_of_data['mRNAseq'], hidden1),nn.Sigmoid())
51
            self.encoder2 = nn.Sequential(nn.Linear(length_of_data['miRNAseq'], hidden2),nn.Sigmoid())
52
            self.classifier = nn.Sequential(nn.Linear(hidden1 + hidden2, label_dim),nn.Sigmoid())
53
            
54
        if input_dim == length_of_data['mRNAseq'] + length_of_data['miRNAseq'] + length_of_data['CNB'] + length_of_data['TMB']: # mRNAseq + miRNAseq + CNB + TMB
55
            hidden_cnv, hidden_tmb = length_of_data['CNB'], length_of_data['TMB']
56
            self.encoder1 = nn.Sequential(nn.Linear(length_of_data['mRNAseq'], hidden1),nn.Sigmoid())
57
            self.encoder2 = nn.Sequential(nn.Linear(length_of_data['miRNAseq'], hidden2),nn.Sigmoid())
58
            self.classifier = nn.Sequential(nn.Linear(hidden1 + hidden2 + hidden_cnv + hidden_tmb, label_dim),nn.Sigmoid())
59
                        
60
        if input_dim == length_of_data['mRNAseq'] + length_of_data['miRNAseq'] + length_of_data['CNB'] + length_of_data['TMB'] + length_of_data['clinical']: # mRNAseq + miRNAseq + CNB + TMB + clinical
61
            hidden_cnv, hidden_tmb, hidden_clinical = length_of_data['CNB'], length_of_data['TMB'], length_of_data['clinical']
62
            self.encoder1 = nn.Sequential(nn.Linear(length_of_data['mRNAseq'], hidden1),nn.Sigmoid())
63
            self.encoder2 = nn.Sequential(nn.Linear(length_of_data['miRNAseq'], hidden2),nn.Sigmoid())
64
            self.classifier = nn.Sequential(nn.Linear(hidden1 + hidden2 + \
65
                                            hidden_cnv + hidden_tmb + hidden_clinical, label_dim),nn.Sigmoid())
66
            
67
        if input_dim == length_of_data['CNB'] + length_of_data['TMB'] + length_of_data['clinical']: # CNB + TMB + clinical
68
            hidden_cnv, hidden_tmb, hidden_clinical = length_of_data['CNB'], length_of_data['TMB'], length_of_data['clinical']
69
            self.classifier = nn.Sequential(nn.Linear(hidden_cnv + hidden_tmb + hidden_clinical, label_dim),nn.Sigmoid())
70
        
71
        if input_dim == length_of_data['mRNAseq'] + length_of_data['miRNAseq'] + length_of_data['clinical']: # mRNAseq + miRNAseq + clinical
72
            hidden_clinical = length_of_data['clinical']
73
            self.encoder1 = nn.Sequential(nn.Linear(length_of_data['mRNAseq'], hidden1),nn.Sigmoid())
74
            self.encoder2 = nn.Sequential(nn.Linear(length_of_data['miRNAseq'], hidden2),nn.Sigmoid())
75
            self.classifier = nn.Sequential(nn.Linear(hidden1 + hidden2 + \
76
                                            hidden_clinical, label_dim),nn.Sigmoid())
77
        
78
    def forward(self, x):
79
        input_dim = x.shape[1]
80
        x_d = None
81
        if input_dim == self.length_of_data['mRNAseq']: # mRNAseq
82
            code1 = self.encoder1(x)
83
            lbl_pred = self.classifier(code1) # predicted label
84
            code = code1
85
            
86
        if input_dim == self.length_of_data['miRNAseq']: # miRNAseq
87
            code2 = self.encoder2(x)
88
            lbl_pred = self.classifier(code2) # predicted label
89
            code = code2
90
            
91
        if input_dim == self.length_of_data['mRNAseq'] + self.length_of_data['miRNAseq']: # mRNAseq + miRNAseq
92
            code1 = self.encoder1(x[:,0:self.length_of_data['mRNAseq']])
93
            code2 = self.encoder2(x[:,self.length_of_data['mRNAseq']:])
94
            lbl_pred = self.classifier(torch.cat((code1, code2), 1)) # predicted label
95
            code = torch.cat((code1, code2), 1)
96
            
97
        if input_dim == self.length_of_data['mRNAseq'] + self.length_of_data['miRNAseq'] + self.length_of_data['CNB'] + self.length_of_data['TMB']: # mRNAseq + miRNAseq + CNB + TMB
98
            code1 = self.encoder1(x[:,0:self.length_of_data['mRNAseq']])
99
            code2 = self.encoder2(x[:,self.length_of_data['mRNAseq']: (self.length_of_data['mRNAseq'] + self.length_of_data['miRNAseq'])])
100
            lbl_pred = self.classifier(torch.cat((code1, code2, x[:,(self.length_of_data['mRNAseq'] + self.length_of_data['miRNAseq']):]), 1)) # predicted label
101
            code = torch.cat((code1, code2), 1)
102
                        
103
        if input_dim == self.length_of_data['mRNAseq'] + self.length_of_data['miRNAseq'] + self.length_of_data['CNB'] + self.length_of_data['TMB'] + self.length_of_data['clinical']: # mRNAseq + miRNAseq + CNB + TMB + clinical
104
            code1 = self.encoder1(x[:,0:self.length_of_data['mRNAseq']])
105
            code2 = self.encoder2(x[:,self.length_of_data['mRNAseq']: (self.length_of_data['mRNAseq'] + self.length_of_data['miRNAseq'])])
106
            lbl_pred = self.classifier(torch.cat((code1, code2, x[:, (self.length_of_data['mRNAseq'] + self.length_of_data['miRNAseq']):]), 1)) # predicted label
107
            code = torch.cat((code1, code2), 1)
108
            
109
        if input_dim == self.length_of_data['CNB'] + self.length_of_data['TMB'] + self.length_of_data['clinical']: # CNB + TMB + clinical
110
            lbl_pred = self.classifier(x) # predicted label
111
            code = torch.FloatTensor([0])
112
            
113
        if input_dim == self.length_of_data['mRNAseq'] + self.length_of_data['miRNAseq'] + self.length_of_data['clinical']: # mRNAseq + miRNAseq + clinical
114
            code1 = self.encoder1(x[:,0:self.length_of_data['mRNAseq']])
115
            code2 = self.encoder2(x[:,self.length_of_data['mRNAseq']: (self.length_of_data['mRNAseq'] + self.length_of_data['miRNAseq'])])
116
            lbl_pred = self.classifier(torch.cat((code1, code2, x[:, (self.length_of_data['mRNAseq'] + self.length_of_data['miRNAseq']):]), 1)) # predicted label
117
            code = torch.cat((code1, code2), 1)
118
            
119
        return x_d, code, lbl_pred
120
121
122
def accuracy(output, labels):
123
    preds = output.max(1)[1].type_as(labels)
124
    correct = preds.eq(labels).double()
125
    correct = correct.sum()
126
    return correct / len(labels)
127
128
def accuracy_cox(hazards, labels):
129
    # This accuracy is based on estimated survival events against true survival events
130
    hazardsdata = hazards.cpu().numpy().reshape(-1)
131
    median = np.median(hazardsdata)
132
    hazards_dichotomize = np.zeros([len(hazardsdata)], dtype=int)
133
    hazards_dichotomize[hazardsdata > median] = 1
134
    labels = labels.data.cpu().numpy()
135
    correct = np.sum(hazards_dichotomize == labels)
136
    return correct / len(labels)
137
138
def cox_log_rank(hazards, labels, survtime_all):
139
    hazardsdata = hazards.cpu().numpy().reshape(-1)
140
    median = np.median(hazardsdata)
141
    hazards_dichotomize = np.zeros([len(hazardsdata)], dtype=int)
142
    hazards_dichotomize[hazardsdata > median] = 1
143
    survtime_all = survtime_all.data.cpu().numpy().reshape(-1)
144
    idx = hazards_dichotomize == 0
145
    labels = labels.data.cpu().numpy()
146
    T1 = survtime_all[idx]
147
    T2 = survtime_all[~idx]
148
    E1 = labels[idx]
149
    E2 = labels[~idx]
150
    results = logrank_test(T1, T2, event_observed_A=E1, event_observed_B=E2)
151
    pvalue_pred = results.p_value
152
    return(pvalue_pred)
153
    
154
def CIndex(hazards, labels, survtime_all):
155
    labels = labels.data.cpu().numpy()
156
    concord = 0.
157
    total = 0.
158
    N_test = labels.shape[0]
159
    labels = np.asarray(labels, dtype=bool)
160
    for i in range(N_test):
161
        if labels[i] == 1:
162
            for j in range(N_test):
163
                if survtime_all[j] > survtime_all[i]:
164
                    total = total + 1
165
                    if hazards[j] < hazards[i]: concord = concord + 1
166
                    elif hazards[j] < hazards[i]: concord = concord + 0.5
167
168
    return(concord/total)
169
    
170
def CIndex_lifeline(hazards, labels, survtime_all):
171
    labels = labels.data.cpu().numpy()
172
    hazards = hazards.cpu().numpy().reshape(-1)
173
    return(concordance_index(survtime_all, -hazards, labels))
174
        
175
def frobenius_norm_loss(a, b):
176
    loss = torch.sqrt(torch.sum(torch.abs(a-b)**2))
177
    return loss
178
179
def test(model, datasets, whichset, length_of_data, batch_size, cuda, verbose):
180
    x = datasets[whichset]['x']
181
    e = datasets[whichset]['e']
182
    t = datasets[whichset]['t']
183
    X = torch.FloatTensor(x)
184
    OS_event = torch.LongTensor(e)
185
    OS = torch.FloatTensor(t)
186
    dataloader = DataLoader(X, batch_size=batch_size, num_workers=1, pin_memory=True, shuffle=False)
187
    lblloader = DataLoader(OS_event, batch_size=batch_size, num_workers=1, pin_memory=True, shuffle=False)
188
    OSloader = DataLoader(OS, batch_size=batch_size, num_workers=1, pin_memory=True, shuffle=False)
189
    lbl_pred_all = None
190
    lbl_all = None
191
    survtime_all = None
192
    code_final = None
193
    loss_nn_sum = 0
194
    model.eval()
195
    iter = 0
196
    for data, lbl, survtime in zip(dataloader, lblloader, OSloader):
197
        graph = data
198
        graph = Variable(graph)
199
        lbl = Variable(lbl)
200
        if cuda:
201
            model = model.cuda()
202
            graph = graph.cuda()
203
            lbl = lbl.cuda()
204
        # ===================forward=====================
205
        output, code, lbl_pred = model(graph)
206
        if iter == 0:
207
            lbl_pred_all = lbl_pred
208
            lbl_all = lbl
209
            survtime_all = survtime
210
            code_final = code
211
        else:
212
            lbl_pred_all = torch.cat([lbl_pred_all, lbl_pred])
213
            lbl_all = torch.cat([lbl_all, lbl])
214
            survtime_all = torch.cat([survtime_all, survtime])
215
            code_final = torch.cat([code_final, code])
216
            
217
        current_batch_len = len(survtime)
218
        R_matrix_test = np.zeros([current_batch_len, current_batch_len], dtype=int)
219
        for i in range(current_batch_len):
220
            for j in range(current_batch_len):
221
                R_matrix_test[i,j] = survtime[j] >= survtime[i]
222
    
223
        test_R = torch.FloatTensor(R_matrix_test)
224
        test_R = Variable(test_R)
225
        if cuda:
226
            test_R = test_R.cuda()
227
        test_ystatus = lbl
228
        theta = lbl_pred.reshape(-1)
229
        exp_theta = torch.exp(theta)
230
        loss_nn = -torch.mean( (theta - torch.log(torch.sum( exp_theta*test_R ,dim=1))) * test_ystatus.float() )
231
        loss_nn_sum = loss_nn_sum + loss_nn.data.item()
232
        iter += 1
233
    code_final_4_original_data = code_final.data.cpu().numpy()
234
    acc_test = accuracy_cox(lbl_pred_all.data, lbl_all)
235
    pvalue_pred = cox_log_rank(lbl_pred_all.data, lbl_all, survtime_all)
236
    c_index = CIndex_lifeline(lbl_pred_all.data, lbl_all, survtime_all)
237
    if verbose > 0:
238
        print('\n[{:s}]\t\tloss (nn):{:.4f}'.format(whichset, loss_nn_sum),
239
                      'c_index: {:.4f}, p-value: {:.3e}'.format(c_index, pvalue_pred))
240
    return(code_final_4_original_data, loss_nn_sum, acc_test, \
241
           pvalue_pred, c_index, lbl_pred_all.data.cpu().numpy().reshape(-1), OS_event, survtime_all)
242
    
243
def init_weights(m):
244
    if type(m) == nn.Linear:
245
        m.weight.data.normal_(0, 0.5)
246
    
247
def train(datasets, num_epochs, batch_size, learning_rate, dropout_rate,
248
                        lambda_1, length_of_data, cuda, measure, verbose):
249
    
250
251
    x = datasets['train']['x']
252
    e = datasets['train']['e']
253
    t = datasets['train']['t']
254
    nodes_in = x.shape[1]
255
    
256
    X = torch.FloatTensor(x)
257
    OS_event = torch.LongTensor(e)
258
    OS = torch.FloatTensor(t)
259
        
260
    
261
    dataloader = DataLoader(X, batch_size=batch_size, num_workers=0, pin_memory=True, shuffle=False)
262
    lblloader = DataLoader(OS_event, batch_size=batch_size, num_workers=0, pin_memory=True, shuffle=False)
263
    OSloader = DataLoader(OS, batch_size=batch_size, num_workers=0, pin_memory=True, shuffle=False)
264
    
265
    
266
    
267
    cudnn.deterministic = True
268
    torch.cuda.manual_seed_all(666)
269
    torch.manual_seed(666)
270
    random.seed(666)
271
    
272
    model = SALMON(nodes_in, dropout_rate, length_of_data, label_dim = 1)
273
        
274
    if cuda:
275
        model.cuda()
276
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
277
    
278
    c_index_list = {}
279
    c_index_list['train'] = []
280
    c_index_list['test'] = []
281
    loss_nn_all = []
282
    pvalue_all = []
283
    c_index_all = []
284
    acc_train_all = []
285
    c_index_best = 0
286
    code_output = None
287
    
288
289
    for epoch in tqdm(range(num_epochs)):
290
        model.train()
291
        lbl_pred_all = None
292
        lbl_all = None
293
        survtime_all = None
294
        code_final = None
295
        loss_nn_sum = 0
296
        iter = 0
297
        gc.collect()
298
        for data, lbl, survtime in zip(dataloader, lblloader, OSloader):
299
            optimizer.zero_grad() # zero the gradient buffer
300
            graph = data
301
            if cuda:
302
                model = model.cuda()
303
                graph = graph.cuda()
304
                lbl = lbl.cuda()
305
            # ===================forward=====================
306
            output, code, lbl_pred = model(graph)
307
            
308
            if iter == 0:
309
                lbl_pred_all = lbl_pred
310
                survtime_all = survtime
311
                lbl_all = lbl
312
                code_final = code
313
            else:
314
                lbl_pred_all = torch.cat([lbl_pred_all, lbl_pred])
315
                lbl_all = torch.cat([lbl_all, lbl])
316
                survtime_all = torch.cat([survtime_all, survtime])
317
                code_final = torch.cat([code_final, code])
318
            # This calculation credit to Travers Ching https://github.com/traversc/cox-nnet
319
            # Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data
320
            current_batch_len = len(survtime)
321
            R_matrix_train = np.zeros([current_batch_len, current_batch_len], dtype=int)
322
            for i in range(current_batch_len):
323
                for j in range(current_batch_len):
324
                    R_matrix_train[i,j] = survtime[j] >= survtime[i]
325
        
326
            train_R = torch.FloatTensor(R_matrix_train)
327
            if cuda:
328
                train_R = train_R.cuda()
329
            train_ystatus = lbl
330
            
331
            theta = lbl_pred.reshape(-1)
332
            exp_theta = torch.exp(theta)
333
            
334
            loss_nn = -torch.mean( (theta - torch.log(torch.sum( exp_theta*train_R ,dim=1))) * train_ystatus.float() )
335
336
            l1_reg = None
337
            for W in model.parameters():
338
                if l1_reg is None:
339
                    l1_reg = torch.abs(W).sum()
340
                else:
341
                    l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1)
342
            
343
            loss = loss_nn + lambda_1 * l1_reg
344
            if verbose > 0:
345
                print("\nloss_nn: %.4f, L1: %.4f" % (loss_nn, lambda_1 * l1_reg))
346
            loss_nn_sum = loss_nn_sum + loss_nn.data.item()
347
            # ===================backward====================
348
            loss.backward()
349
            optimizer.step()
350
            
351
            iter += 1
352
            torch.cuda.empty_cache()
353
        code_final_4_original_data = code_final.data.cpu().numpy()
354
        
355
        if measure or epoch == (num_epochs - 1):
356
            acc_train = accuracy_cox(lbl_pred_all.data, lbl_all)
357
            pvalue_pred = cox_log_rank(lbl_pred_all.data, lbl_all, survtime_all)
358
            c_index = CIndex_lifeline(lbl_pred_all.data, lbl_all, survtime_all)
359
            
360
            c_index_list['train'].append(c_index)
361
            if c_index > c_index_best:
362
                c_index_best = c_index
363
                code_output = code_final_4_original_data
364
            if verbose > 0:
365
                print('\n[Training]\t loss (nn):{:.4f}'.format(loss_nn_sum),
366
                      'c_index: {:.4f}, p-value: {:.3e}'.format(c_index, pvalue_pred))
367
            pvalue_all.append(pvalue_pred)
368
            c_index_all.append(c_index)
369
            loss_nn_all.append(loss_nn_sum)
370
            acc_train_all.append(acc_train)
371
            whichset = 'test'
372
            code_validation, loss_nn_sum, acc_test, pvalue_pred, c_index_pred, lbl_pred_all, OS_event, OS = \
373
                test(model, datasets, whichset, length_of_data, batch_size, cuda, verbose)
374
                
375
            c_index_list['test'].append(c_index_pred)
376
    return(model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output)