Diff of /HINT/model.py [000000] .. [bc9e98]

Switch to unified view

a b/HINT/model.py
1
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, precision_score, recall_score, accuracy_score
2
import matplotlib.pyplot as plt
3
from copy import deepcopy 
4
import numpy as np 
5
from tqdm import tqdm 
6
import torch 
7
torch.manual_seed(0)
8
from torch import nn 
9
from torch.autograd import Variable
10
import torch.nn.functional as F
11
from HINT.module import Highway, GCN 
12
from functools import reduce 
13
import pickle
14
15
16
class Interaction(nn.Sequential):
    """Base HINT model.

    Encodes a trial's drug molecules (SMILES), disease codes (ICD) and
    eligibility-criteria text with the three supplied encoders, fuses the
    concatenated embeddings through a linear + Highway "interaction" tower,
    and predicts a single trial-success logit. Also implements the shared
    train / evaluate / bootstrap-test loops reused by all subclasses.
    """

    def __init__(self, molecule_encoder, disease_encoder, protocol_encoder,
                    device,
                    global_embed_size,
                    highway_num_layer,
                    prefix_name,
                    epoch = 20,
                    lr = 3e-4,
                    weight_decay = 0,
                    ):
        """
        Args:
            molecule_encoder: encoder exposing forward_smiles_lst_lst and .embedding_size
            disease_encoder: encoder exposing forward_code_lst3 and .embedding_size
            protocol_encoder: encoder exposing forward(criteria_lst) and .embedding_size
            device: torch device all submodules are moved to
            global_embed_size: width of the fused interaction embedding
            highway_num_layer: number of layers in each Highway block
            prefix_name: prefix used for figure / checkpoint file names
            epoch: number of training epochs
            lr: Adam learning rate
            weight_decay: Adam weight decay
        """
        super(Interaction, self).__init__()
        self.molecule_encoder = molecule_encoder
        self.disease_encoder = disease_encoder
        self.protocol_encoder = protocol_encoder
        self.global_embed_size = global_embed_size
        self.highway_num_layer = highway_num_layer
        # width of the concatenation of the three encoder outputs
        self.feature_dim = self.molecule_encoder.embedding_size + self.disease_encoder.embedding_size + self.protocol_encoder.embedding_size
        self.epoch = epoch
        self.lr = lr
        self.weight_decay = weight_decay
        self.save_name = prefix_name + '_interaction'

        self.f = F.relu
        self.loss = nn.BCEWithLogitsLoss()

        ##### NN
        self.encoder2interaction_fc = nn.Linear(self.feature_dim, self.global_embed_size).to(device)
        self.encoder2interaction_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
        self.pred_nn = nn.Linear(self.global_embed_size, 1)

        self.device = device
        self = self.to(device)

    def feed_lst_of_module(self, input_feature, lst_of_module):
        """Apply each module in order, applying self.f (ReLU) after every one."""
        x = input_feature
        for single_module in lst_of_module:
            x = self.f(single_module(x))
        return x

    def forward_get_three_encoders(self, smiles_lst2, icdcode_lst3, criteria_lst):
        """Run the three encoders; return (molecule, icd, protocol) embeddings."""
        molecule_embed = self.molecule_encoder.forward_smiles_lst_lst(smiles_lst2)
        icd_embed = self.disease_encoder.forward_code_lst3(icdcode_lst3)
        protocol_embed = self.protocol_encoder.forward(criteria_lst)
        return molecule_embed, icd_embed, protocol_embed

    def forward_encoder_2_interaction(self, molecule_embed, icd_embed, protocol_embed):
        """Fuse encoder outputs: concat -> linear -> ReLU -> Highway -> ReLU."""
        encoder_embedding = torch.cat([molecule_embed, icd_embed, protocol_embed], 1)
        h = self.encoder2interaction_fc(encoder_embedding)
        h = self.f(h)
        h = self.encoder2interaction_highway(h)
        interaction_embedding = self.f(h)
        return interaction_embedding

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
        """Return raw (pre-sigmoid) success logits, shape (batch, 1)."""
        molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst)
        interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed)
        output = self.pred_nn(interaction_embedding)
        return output ### 32, 1

    def evaluation(self, predict_all, label_all, threshold = 0.5):
        """Compute binary-classification metrics.

        Side effect: dumps (prediction, label) pairs to "predict_label.txt".
        ROC-AUC is computed on the continuous scores; the remaining metrics
        are computed after binarizing predictions at `threshold`.

        Returns:
            (auc, f1, prauc, precision, recall, accuracy,
             predicted-positive ratio, true-positive ratio)
        """
        with open("predict_label.txt", 'w') as fout:
            for i,j in zip(predict_all, label_all):
                fout.write(str(i)[:6] + '\t' + str(j)[:4]+'\n')
        auc_score = roc_auc_score(label_all, predict_all)
        label_all = [int(i) for i in label_all]
        float2binary = lambda x:0 if x < threshold else 1
        predict_all = list(map(float2binary, predict_all))
        f1score = f1_score(label_all, predict_all)
        # NOTE(review): PR-AUC is computed on the *binarized* predictions, not
        # the continuous scores, so it degenerates to a single operating point
        # — kept as-is since reported numbers depend on it; confirm intent.
        prauc_score = average_precision_score(label_all, predict_all)
        precision = precision_score(label_all, predict_all)
        recall = recall_score(label_all, predict_all)
        accuracy = accuracy_score(label_all, predict_all)
        predict_1_ratio = sum(predict_all) / len(predict_all)
        label_1_ratio = sum(label_all) / len(label_all)
        return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio

    def testloader_to_lst(self, dataloader):
        """Flatten a dataloader into parallel lists plus the total length."""
        nctid_lst, label_lst, smiles_lst2, icdcode_lst3, criteria_lst = [], [], [], [], []
        for nctid, label, smiles, icdcode, criteria in dataloader:
            nctid_lst.extend(nctid)
            label_lst.extend([i.item() for i in label])
            smiles_lst2.extend(smiles)
            icdcode_lst3.extend(icdcode)
            criteria_lst.extend(criteria)
        length = len(nctid_lst)
        assert length == len(smiles_lst2) and length == len(icdcode_lst3)
        return nctid_lst, label_lst, smiles_lst2, icdcode_lst3, criteria_lst, length

    def generate_predict(self, dataloader):
        """Run inference over a loader.

        Returns:
            (summed BCE loss, sigmoid scores, labels, nct ids), each list in
            loader order.
        """
        whole_loss = 0
        label_all, predict_all, nctid_all = [], [], []
        with torch.no_grad():  # inference only; no autograd bookkeeping needed
            for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader:
                nctid_all.extend(nctid_lst)
                label_vec = label_vec.to(self.device)
                output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst).view(-1)
                loss = self.loss(output, label_vec.float())
                whole_loss += loss.item()
                predict_all.extend([i.item() for i in torch.sigmoid(output)])
                label_all.extend([i.item() for i in label_vec])
        return whole_loss, predict_all, label_all, nctid_all

    def bootstrap_test(self, dataloader, valid_loader = None, sample_num = 20):
        """Evaluate with bootstrap resampling and print mean/std of the metrics.

        Draws `sample_num` bootstrap resamples of the predictions, evaluates
        each, and prints PR-AUC / F1 / ROC-AUC summaries. Also plots the score
        histogram, prints misclassified trials, and pickles nctid->prediction.
        """
        best_threshold = 0.5
        self.eval()
        whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader)
        from HINT.utils import plot_hist
        plt.clf()
        prefix_name = "./figure/" + self.save_name
        plot_hist(prefix_name, predict_all, label_all)
        def bootstrap(length, sample_num):
            # sample_num index lists, each sampled with replacement
            idx = [i for i in range(length)]
            from random import choices
            bootstrap_idx = [choices(idx, k = length) for i in range(sample_num)]
            return bootstrap_idx
        results_lst = []
        bootstrap_idx_lst = bootstrap(len(predict_all), sample_num = sample_num)
        for bootstrap_idx in bootstrap_idx_lst:
            bootstrap_label = [label_all[idx] for idx in bootstrap_idx]
            bootstrap_predict = [predict_all[idx] for idx in bootstrap_idx]
            results = self.evaluation(bootstrap_predict, bootstrap_label, threshold = best_threshold)
            results_lst.append(results)
        self.train()
        auc = [results[0] for results in results_lst]
        f1score = [results[1] for results in results_lst]
        prauc_score = [results[2] for results in results_lst]
        print("PR-AUC   mean: "+str(np.mean(prauc_score))[:6], "std: "+str(np.std(prauc_score))[:6])
        print("F1       mean: "+str(np.mean(f1score))[:6], "std: "+str(np.std(f1score))[:6])
        print("ROC-AUC  mean: "+str(np.mean(auc))[:6], "std: "+str(np.std(auc))[:6])

        # print misclassified trials (at the 0.5 operating point)
        for nctid, label, predict in zip(nctid_all, label_all, predict_all):
            if (predict > 0.5 and label == 0) or (predict < 0.5 and label == 1):
                print(nctid, label, str(predict)[:6])

        nctid2predict = {nctid:predict for nctid, predict in zip(nctid_all, predict_all)}
        pickle.dump(nctid2predict, open('results/nctid2predict.pkl', 'wb'))
        return nctid_all, predict_all

    def ongoing_test(self, dataloader, sample_num = 20):
        """Inference only (no labels used): return (nct ids, sigmoid scores)."""
        self.eval()
        best_threshold = 0.5
        whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader)
        self.train()
        return nctid_all, predict_all

    def test(self, dataloader, return_loss = True, validloader=None):
        """Evaluate on a loader.

        Returns (loss, predictions, labels) when `return_loss` is True;
        otherwise prints and returns the full metric tuple from evaluation().
        """
        self.eval()
        best_threshold = 0.5
        whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader)
        self.train()
        if return_loss:
            return whole_loss, predict_all, label_all
        else:
            print_num = 6
            auc_score, f1score, prauc_score, precision, recall, accuracy, \
            predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all, threshold = best_threshold)
            print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num] \
                 + "\nPR-AUC: " + str(prauc_score)[:print_num] \
                 + "\nPrecision: " + str(precision)[:print_num] \
                 + "\nrecall: "+str(recall)[:print_num] + "\naccuracy: "+str(accuracy)[:print_num] \
                 + "\npredict 1 ratio: " + str(predict_1_ratio)[:print_num] \
                 + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num])
            return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio

    def learn(self, train_loader, valid_loader, test_loader):
        """Train with Adam, keep the model with the best validation loss,
        restore it, and report test metrics.

        Returns:
            (train_output, valid_output) — per-step training records and
            per-epoch validation records.
        """
        opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
        train_loss_record = []
        valid_loss, valid_predict, valid_label = self.test(valid_loader, return_loss=True)
        valid_loss_record = [valid_loss]
        best_valid_loss = valid_loss
        best_model = deepcopy(self)
        train_output = []
        valid_output = []
        for ep in tqdm(range(self.epoch)):
            for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader:
                label_vec = label_vec.to(self.device)
                output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst).view(-1)  #### 32, 1 -> 32, ||  label_vec 32,
                loss = self.loss(output, label_vec.float())
                train_loss_record.append(loss.item())
                train_output.append((loss.item(), output, label_vec))
                opt.zero_grad()
                loss.backward()
                opt.step()
            valid_loss, valid_predict, valid_label = self.test(valid_loader, return_loss=True)
            valid_loss_record.append(valid_loss)
            valid_output.append((valid_loss, valid_predict, valid_label))

            print(f"valid_loss: {valid_loss}")
            print(best_valid_loss)
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_model = deepcopy(self)

        self.plot_learning_curve(train_loss_record, valid_loss_record)
        # BUG FIX: `self = deepcopy(best_model)` only rebound the local name
        # and never restored the best weights; copy the parameters back in.
        self.load_state_dict(best_model.state_dict())
        auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader)
        return train_output, valid_output

    def plot_learning_curve(self, train_loss_record, valid_loss_record):
        """Save train/valid loss curves under ./figure/."""
        plt.plot(train_loss_record)
        plt.savefig("./figure/" + self.save_name + '_train_loss.jpg')
        plt.clf()
        plt.plot(valid_loss_record)
        plt.savefig("./figure/" + self.save_name + '_valid_loss.jpg')
        plt.clf()

    def select_threshold_for_binary(self, validloader):
        """Pick the decision threshold that maximizes F1 on the validation set.

        Every predicted score is tried as a candidate threshold; defaults to
        0.5 when no candidate yields a positive F1.
        """
        _, prediction, label_all, nctid_all = self.generate_predict(validloader)
        # BUG FIX: the original scored candidates with precision_score while
        # calling the result f1score (and the method promises F1); it also
        # left best_threshold unbound when every score was 0.
        best_f1, best_threshold = 0, 0.5
        for threshold in prediction:
            float2binary = lambda x:0 if x<threshold else 1
            predict_all = list(map(float2binary, prediction))
            f1score = f1_score(label_all, predict_all)
            if f1score > best_f1:
                best_f1 = f1score
                best_threshold = threshold
        return best_threshold
258
259
260
class HINTModel_multi(Interaction):
    """Multi-class (4-way) variant of Interaction.

    Reuses the encoder + interaction tower from Interaction but replaces the
    prediction head with a 4-logit linear layer trained with cross-entropy.
    """

    def __init__(self, molecule_encoder, disease_encoder, protocol_encoder,
                    device,
                    global_embed_size,
                    highway_num_layer,
                    prefix_name,
                    epoch = 20,
                    lr = 3e-4,
                    weight_decay = 0,
                    ):
        super(HINTModel_multi, self).__init__(molecule_encoder = molecule_encoder,
                                   disease_encoder = disease_encoder,
                                   protocol_encoder = protocol_encoder,
                                   device = device,
                                   prefix_name = prefix_name,
                                   global_embed_size = global_embed_size,
                                   highway_num_layer = highway_num_layer,
                                   epoch = epoch,
                                   lr = lr,
                                   weight_decay = weight_decay)
        # override the binary head/loss from Interaction with 4-class versions
        self.pred_nn = nn.Linear(self.global_embed_size, 4)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
        """Return raw class logits, shape (batch, 4)."""
        molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst)
        interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed)
        output = self.pred_nn(interaction_embedding)
        return output ### 32, 4

    def generate_predict(self, dataloader):
        """Run inference over a loader.

        Returns:
            (summed CE loss, argmax class predictions, labels, accuracy)
        """
        whole_loss = 0
        label_all, predict_all = [], []
        with torch.no_grad():  # inference only; no autograd bookkeeping needed
            for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader:
                label_vec = label_vec.to(self.device)
                output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst)
                loss = self.loss(output, label_vec)
                whole_loss += loss.item()
                predict_all.extend(torch.argmax(output, 1).tolist())
                label_all.extend([i.item() for i in label_vec])

        accuracy = len(list(filter(lambda x:x[0]==x[1], zip(predict_all, label_all)))) / len(label_all)
        return whole_loss, predict_all, label_all, accuracy

    def test(self, dataloader, return_loss = True, validloader=None):
        """Evaluate on a loader; return (loss, predictions, labels, accuracy).

        `return_loss` and `validloader` are kept only for signature
        compatibility with Interaction.test and are not used here.
        """
        self.eval()
        whole_loss, predict_all, label_all, accuracy = self.generate_predict(dataloader)
        self.train()
        return whole_loss, predict_all, label_all, accuracy

    def learn(self, train_loader, valid_loader, test_loader):
        """Train with Adam, printing validation accuracy each epoch.

        Returns:
            (predict_all, label_all) from the final validation pass.
        """
        opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
        train_loss_record = []
        valid_loss, predict_all, label_all, accuracy = self.test(valid_loader, return_loss=True)
        print('accuracy', accuracy)
        for ep in tqdm(range(self.epoch)):
            self.train()
            for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader:
                label_vec = label_vec.to(self.device)
                output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst)  #### 32, 1 -> 32, ||  label_vec 32,
                loss = self.loss(output, label_vec)
                train_loss_record.append(loss.item())
                opt.zero_grad()
                loss.backward()
                opt.step()
            valid_loss, predict_all, label_all, accuracy = self.test(valid_loader, return_loss=True)
            print('accuracy', accuracy)
        # NOTE(review): no best-model selection / test-set evaluation here —
        # the original left that code commented out; confirm before relying on
        # the returned validation predictions as final results.
        return predict_all, label_all
361
362
363
class HINT_nograph(Interaction):
    """HINT without the GNN message-passing stage.

    Extends Interaction with the remaining trial-graph node towers (disease
    risk, augmented interaction, five ADMET heads, PK, trial node), composed
    with plain feed-forward wiring instead of a graph network.
    """

    def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, device,
                    global_embed_size,
                    highway_num_layer,
                    prefix_name,
                    epoch = 20,
                    lr = 3e-4,
                    weight_decay = 0, ):
        super(HINT_nograph, self).__init__(molecule_encoder = molecule_encoder,
                                   disease_encoder = disease_encoder,
                                   protocol_encoder = protocol_encoder,
                                   device = device,
                                   global_embed_size = global_embed_size,
                                   prefix_name = prefix_name,
                                   highway_num_layer = highway_num_layer,
                                   epoch = epoch,
                                   lr = lr,
                                   weight_decay = weight_decay,
                                   )
        self.save_name = prefix_name + '_HINT_nograph'

        #### risk of disease
        # "higway" typo kept: renaming would change state_dict keys and break
        # existing checkpoints.
        self.risk_disease_fc = nn.Linear(self.disease_encoder.embedding_size, self.global_embed_size)
        self.risk_disease_higway = Highway(self.global_embed_size, self.highway_num_layer)

        #### augment interaction
        self.augment_interaction_fc = nn.Linear(self.global_embed_size*2, self.global_embed_size)
        self.augment_interaction_highway = Highway(self.global_embed_size, self.highway_num_layer)

        #### ADMET: five parallel (fc, highway) towers over the molecule embedding
        self.admet_model = []
        for i in range(5):
            admet_fc = nn.Linear(self.molecule_encoder.embedding_size, self.global_embed_size).to(device)
            admet_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
            self.admet_model.append(nn.ModuleList([admet_fc, admet_highway]))
        self.admet_model = nn.ModuleList(self.admet_model)

        #### PK: fuses the five ADMET embeddings
        self.pk_fc = nn.Linear(self.global_embed_size*5, self.global_embed_size)
        self.pk_highway = Highway(self.global_embed_size, self.highway_num_layer)

        #### trial node: fuses PK and augmented-interaction embeddings
        self.trial_fc = nn.Linear(self.global_embed_size*2, self.global_embed_size)
        self.trial_highway = Highway(self.global_embed_size, self.highway_num_layer)

        ## self.pred_nn = nn.Linear(self.global_embed_size, 1)  # inherited from Interaction

        self.device = device
        self = self.to(device)

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst, if_gnn = False):
        """Run the full node-tower pipeline.

        Returns the (batch, 1) success logits when `if_gnn` is False;
        otherwise the list of all 13 per-node embeddings (molecule, icd,
        protocol, interaction, risk-of-disease, augmented interaction, 5x
        ADMET, PK, trial) for the GNN subclass to consume.
        """
        ### encoder for molecule, disease and protocol
        molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst)
        ### interaction
        interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed)
        ### risk of disease
        risk_of_disease_embedding = self.feed_lst_of_module(input_feature = icd_embed,
                                                            lst_of_module = [self.risk_disease_fc, self.risk_disease_higway])
        ### augment interaction
        augment_interaction_input = torch.cat([interaction_embedding, risk_of_disease_embedding], 1)
        augment_interaction_embedding = self.feed_lst_of_module(input_feature = augment_interaction_input,
                                                                lst_of_module = [self.augment_interaction_fc, self.augment_interaction_highway])
        ### admet
        admet_embedding_lst = []
        for idx in range(5):
            admet_embedding = self.feed_lst_of_module(input_feature = molecule_embed,
                                                      lst_of_module = self.admet_model[idx])
            admet_embedding_lst.append(admet_embedding)
        ### pk
        pk_input = torch.cat(admet_embedding_lst, 1)
        pk_embedding = self.feed_lst_of_module(input_feature = pk_input,
                                               lst_of_module = [self.pk_fc, self.pk_highway])
        ### trial
        trial_input = torch.cat([pk_embedding, augment_interaction_embedding], 1)
        trial_embedding = self.feed_lst_of_module(input_feature = trial_input,
                                                  lst_of_module = [self.trial_fc, self.trial_highway])
        output = self.pred_nn(trial_embedding)
        if if_gnn == False:
            return output
        else:
            embedding_lst = [molecule_embed, icd_embed, protocol_embed, interaction_embedding, risk_of_disease_embedding, \
                             augment_interaction_embedding] + admet_embedding_lst + [pk_embedding, trial_embedding]
            return embedding_lst
466
467
468
class HINTModel(HINT_nograph):
    """Full HINT model: HINT_nograph node towers plus a GCN over the fixed
    13-node trial graph, with a learned sigmoid attention weight per edge."""

    def __init__(self, molecule_encoder, disease_encoder, protocol_encoder,
                    device,
                    global_embed_size,
                    highway_num_layer,
                    prefix_name,
                    gnn_hidden_size,
                    epoch = 20,
                    lr = 3e-4,
                    weight_decay = 0,):
        super(HINTModel, self).__init__(molecule_encoder = molecule_encoder,
                                   disease_encoder = disease_encoder,
                                   protocol_encoder = protocol_encoder,
                                   device = device,
                                   prefix_name = prefix_name,
                                   global_embed_size = global_embed_size,
                                   highway_num_layer = highway_num_layer,
                                   epoch = epoch,
                                   lr = lr,
                                   weight_decay = weight_decay)
        self.save_name = prefix_name
        self.gnn_hidden_size = gnn_hidden_size
        #### GNN
        self.adj = self.generate_adj()
        self.gnn = GCN(
            nfeat = self.global_embed_size,
            nhid = self.gnn_hidden_size,
            nclass = 1,
            dropout = 0.6,
            init = 'uniform')
        ### gnn's attention
        self.node_size = self.adj.shape[0]
        # one attention module per adjacency entry equal to 1 (i.e. per edge,
        # in both directions); None elsewhere so indexing stays [i][j].
        self.graph_attention_model_mat = nn.ModuleList([nn.ModuleList([self.gnn_attention() if self.adj[i,j]==1 else None for j in range(self.node_size)]) for i in range(self.node_size)])

        self.device = device
        self = self.to(device)

    def generate_adj(self):
        """Build the fixed symmetric adjacency of the 13-node trial graph.

        Off-diagonal edges are 1. NOTE(review): the diagonal is set to
        len(lst) (=13), not 1 — so no attention module is created for
        self-loops and `self.adj * attention_mat` zeroes the diagonal in
        forward; confirm the GCN re-adds self-loops internally if needed.
        """
        ##### consistent with HINT_nograph.forward
        lst = ["molecule", "disease", "criteria", 'INTERACTION', 'risk_disease', 'augment_interaction', 'A', 'D', 'M', 'E', 'T', 'PK', "final"]
        edge_lst = [("disease", "molecule"), ("disease", "criteria"), ("molecule", "criteria"),
                    ("disease", "INTERACTION"), ("molecule", "INTERACTION"),  ("criteria", "INTERACTION"),
                    ("disease", "risk_disease"), ('risk_disease', 'augment_interaction'), ('INTERACTION', 'augment_interaction'),
                    ("molecule", "A"), ("molecule", "D"), ("molecule", "M"), ("molecule", "E"), ("molecule", "T"),
                    ('A', 'PK'), ('D', 'PK'), ('M', 'PK'), ('E', 'PK'), ('T', 'PK'),
                    ('augment_interaction', 'final'), ('PK', 'final')]
        adj = torch.eye(len(lst)) * len(lst)
        str2num = {v:k for k,v in enumerate(lst)}
        for i,j in edge_lst:
            n1,n2 = str2num[i], str2num[j]
            adj[n1,n2] = 1
            adj[n2,n1] = 1
        return adj.to(self.device)

    def generate_attention_matrx(self, node_feature_mat):
        """Compute a (node_size, node_size) matrix of sigmoid edge-attention
        scores for one sample; entries without an edge (adj != 1) stay 0."""
        attention_mat = torch.zeros(self.node_size, self.node_size).to(self.device)
        for i in range(self.node_size):
            for j in range(self.node_size):
                if self.adj[i,j]!=1:
                    continue
                # attention input: the two endpoint node features, concatenated
                feature = torch.cat([node_feature_mat[i].view(1,-1), node_feature_mat[j].view(1,-1)], 1)
                attention_model = self.graph_attention_model_mat[i][j]
                attention_mat[i,j] = torch.sigmoid(self.feed_lst_of_module(input_feature=feature, lst_of_module=attention_model))
        return attention_mat

    ##### self.global_embed_size*2 -> 1
    def gnn_attention(self):
        """Build one edge-attention tower: Highway over the concatenated pair
        of node embeddings, then a linear layer down to a single score."""
        highway_nn = Highway(size = self.global_embed_size*2, num_layers = self.highway_num_layer).to(self.device)
        highway_fc = nn.Linear(self.global_embed_size*2, 1).to(self.device)
        return nn.ModuleList([highway_nn, highway_fc])

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst, return_attention_matrix = False):
        """Run node towers, then a per-sample attention-weighted GCN pass.

        Returns the (batch, 1) logits taken from the last ("final") node,
        plus the per-sample attention matrices when requested.
        """
        embedding_lst = HINT_nograph.forward(self, smiles_lst2, icdcode_lst3, criteria_lst, if_gnn = True)
        ### length is 13, each is 32,50
        batch_size = embedding_lst[0].shape[0]
        output_lst = []
        if return_attention_matrix:
            attention_mat_lst = []
        # the GCN is applied per sample: each sample forms its own 13-node graph
        for i in range(batch_size):
            node_feature_lst = [embedding[i].view(1,-1) for embedding in embedding_lst]
            node_feature_mat = torch.cat(node_feature_lst, 0) ### 13, 50
            attention_mat = self.generate_attention_matrx(node_feature_mat)
            output = self.gnn(node_feature_mat, self.adj * attention_mat)
            output = output[-1].view(1,-1)  # "final" node's score
            output_lst.append(output)
            if return_attention_matrix:
                attention_mat_lst.append(attention_mat)
        output_mat = torch.cat(output_lst, 0)
        if not return_attention_matrix:
            return output_mat
        else:
            return output_mat, attention_mat_lst

    def interpret(self, complete_dataloader):
        """Render one attention-graph figure per trial into interpret_result/."""
        from graph_visualize_interpret import data2graph
        from HINT.utils import replace_strange_symbol
        for nctid_lst, status_lst, why_stop_lst, label_vec, phase_lst, \
            diseases_lst, icdcode_lst3, drugs_lst, smiles_lst2, criteria_lst in complete_dataloader:
            output, attention_mat_lst = self.forward(smiles_lst2, icdcode_lst3, criteria_lst, return_attention_matrix=True)
            output = output.view(-1)
            batch_size = len(nctid_lst)
            for i in range(batch_size):
                # file name encodes trial id, status, label, score, phase, ...
                name = '__'.join([nctid_lst[i], status_lst[i], why_stop_lst[i], \
                                                        str(label_vec[i].item()), str(torch.sigmoid(output[i]).item())[:5], \
                                                        phase_lst[i], diseases_lst[i], drugs_lst[i]])
                # NOTE(review): checks > 150 but truncates to 250 characters —
                # likely one of the two constants is a typo; confirm intent.
                if len(name) > 150:
                    name = name[:250]
                name = replace_strange_symbol(name)
                name = name.replace('__', '_')
                name = name.replace('  ', ' ')
                name = 'interpret_result/' + name + '.png'
                print(name)
                data2graph(attention_matrix = attention_mat_lst[i], adj = self.adj, save_name = name)

    def init_pretrain(self, admet_model):
        """Warm-start the molecule encoder from a pretrained ADMET model."""
        self.molecule_encoder = admet_model.molecule_encoder
600
601
602
class Only_Molecule(Interaction):
    """Ablation model: predicts trial outcome from the molecule embedding
    only; disease codes and eligibility criteria are accepted by forward()
    for interface compatibility but ignored.
    """

    def __init__(self, molecule_encoder, disease_encoder, protocol_encoder,
                    global_embed_size,
                    highway_num_layer,
                    prefix_name,
                    epoch = 20,
                    lr = 3e-4,
                    weight_decay = 0,
                    device = torch.device('cpu')):
        # bug fix: Interaction.__init__ requires `device` but it was never
        # passed, so instantiation raised TypeError.  Accept it as a new
        # keyword (default CPU, backward compatible) and forward it.
        super(Only_Molecule, self).__init__(molecule_encoder=molecule_encoder,
                                            disease_encoder=disease_encoder,
                                            protocol_encoder=protocol_encoder,
                                            device=device,
                                            global_embed_size = global_embed_size,
                                            highway_num_layer = highway_num_layer,
                                            prefix_name = prefix_name,
                                            epoch = epoch,
                                            lr = lr,
                                            weight_decay = weight_decay,)
        # single scalar head on top of the molecule embedding
        self.molecule2out = nn.Linear(self.global_embed_size, 1)

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
        """Embed the SMILES lists and map them to one raw score per trial."""
        molecule_embed = self.molecule_encoder.forward_smiles_lst_lst(smiles_lst2)
        return self.molecule2out(molecule_embed)
627
class Only_Disease(Only_Molecule):
    """Ablation model that scores trials from the disease (ICD-code)
    embedding alone; molecule and criteria inputs are accepted by forward()
    but unused.
    """

    def __init__(self, molecule_encoder, disease_encoder, protocol_encoder,
                    global_embed_size,
                    highway_num_layer,
                    prefix_name,
                    epoch = 20,
                    lr = 3e-4,
                    weight_decay = 0):
        base_kwargs = dict(
            molecule_encoder = molecule_encoder,
            disease_encoder = disease_encoder,
            protocol_encoder = protocol_encoder,
            global_embed_size = global_embed_size,
            highway_num_layer = highway_num_layer,
            prefix_name = prefix_name,
            epoch = epoch,
            lr = lr,
            weight_decay = weight_decay,
        )
        super(Only_Disease, self).__init__(**base_kwargs)
        # reuse the scalar head built by Only_Molecule under a disease-flavored alias
        self.disease2out = self.molecule2out

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
        """Embed the ICD codes and map them to one raw score per trial."""
        icd_embed = self.disease_encoder.forward_code_lst3(icdcode_lst3)
        return self.disease2out(icd_embed)
652
def dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, global_icd):
    """Turn one raw dataloader batch into (X, y) for the FFNN baseline.

    X is a float tensor [fingerprint block | multi-hot ICD block] of shape
    (n, fp_dim + len(global_icd)) -- fp_dim is presumably 2048 Morgan bits,
    confirm against HINT.utils.smiles_lst2fp.  y is label_vec unchanged.
    nctid_lst and criteria_lst are accepted for interface compatibility
    but unused here.
    """
    y = label_vec

    num_icd = len(global_icd)
    from HINT.utils import smiles_lst2fp
    fp_lst = [smiles_lst2fp(smiles_lst).reshape(1, -1) for smiles_lst in smiles_lst2]
    fp_mat = np.concatenate(fp_lst, 0)  # fingerprint block, one row per trial

    # O(1) code lookups; the original called global_icd.index() per code,
    # which is O(len(global_icd)) each time (accidental quadratic).
    icd2idx = {code: i for i, code in enumerate(global_icd)}
    icd_rows = []
    for lst2 in icdcode_lst3:
        # flatten the per-trial list-of-lists and keep only the ICD prefix
        # before the dot; reduce now has an initial value so an empty list
        # yields an all-zero row instead of raising TypeError.
        codes = {code.split('.')[0] for code in reduce(lambda a, b: a + b, lst2, [])}
        row = np.zeros((1, num_icd), np.int32)
        for code in codes:
            idx = icd2idx.get(code)
            if idx is not None:
                row[0, idx] = 1
        icd_rows.append(row)
    icdcode_mat = np.concatenate(icd_rows, 0)
    X = torch.from_numpy(np.concatenate([fp_mat, icdcode_mat], 1))
    X = X.float()
    return X, y
682
683
class FFNN(nn.Sequential):
    """Feed-forward baseline over fingerprint + ICD multi-hot features.

    Consumes the X matrix produced by dataloader2Xy and emits one logit per
    trial.  Trained with BCEWithLogitsLoss; probabilities are obtained via
    torch.sigmoid in generate_predict.
    """
    def __init__(self, molecule_dim, diseasecode_dim, 
                    global_icd, 
                    protocol_dim = 0,
                    prefix_name = 'FFNN', 
                    epoch = 10,
                    lr = 3e-4, 
                    weight_decay = 0, 
                    ):
        """
        molecule_dim    -- width of the fingerprint block
        diseasecode_dim -- width of the multi-hot ICD block
        global_icd      -- ordered list of ICD code prefixes (feature columns)
        protocol_dim    -- optional extra feature width (default 0)
        """
        super(FFNN, self).__init__()
        self.molecule_dim = molecule_dim 
        self.diseasecode_dim = diseasecode_dim 
        self.protocol_dim = protocol_dim 
        self.prefix_name = prefix_name 
        self.epoch = epoch 
        self.lr = lr 
        self.weight_decay = weight_decay 
        self.global_icd = global_icd 
        self.num_icd = len(global_icd)

        # MLP widths: input -> 2000 -> 1000 -> 200 -> 50 -> 1
        self.fc_dims = [self.molecule_dim + self.diseasecode_dim + self.protocol_dim, 2000, 1000, 200, 50, 1]
        self.fc_layers = nn.ModuleList([nn.Linear(v, self.fc_dims[i+1]) for i, v in enumerate(self.fc_dims[:-1])])
        self.loss = nn.BCEWithLogitsLoss()
        self.save_name = prefix_name 

    def forward(self, X):
        """Return raw logits of shape (batch, 1).

        bug fix: the original ended with F.sigmoid even though self.loss is
        BCEWithLogitsLoss (which applies sigmoid internally) and
        generate_predict applies torch.sigmoid again -- probabilities were
        squashed twice.  forward now returns logits.
        NOTE(review): there is no activation between hidden layers, so the
        stack is effectively one linear map -- confirm whether intended.
        """
        for fc_layer in self.fc_layers[:-1]:
            X = fc_layer(X)
        return self.fc_layers[-1](X)

    def learn(self, train_loader, valid_loader, test_loader):
        """Train with Adam, keep the checkpoint with the lowest validation
        loss, plot learning curves, then report metrics on test_loader."""
        opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
        train_loss_record = [] 
        valid_loss = self.test(valid_loader, return_loss=True)
        valid_loss_record = [valid_loss]
        best_valid_loss = valid_loss
        best_model = deepcopy(self)

        for ep in tqdm(range(self.epoch)):
            for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader:
                X, _ = dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, self.global_icd)
                output = self.forward(X).view(-1)  # (n,1) -> (n,) to match label_vec
                loss = self.loss(output, label_vec.float())
                train_loss_record.append(loss.item())
                opt.zero_grad() 
                loss.backward() 
                opt.step()
            valid_loss = self.test(valid_loader, return_loss=True)
            valid_loss_record.append(valid_loss)
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss 
                best_model = deepcopy(self)

        self.plot_learning_curve(train_loss_record, valid_loss_record)
        # bug fix: `self = deepcopy(best_model)` only rebound the local name
        # and silently discarded the best checkpoint; restore its weights.
        self.load_state_dict(best_model.state_dict())
        auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader)

    def evaluation(self, predict_all, label_all, threshold = 0.5):
        """Compute ROC-AUC, F1, PR-AUC, precision, recall, accuracy and the
        predicted/true positive ratios; binarization uses `threshold`.
        Also dumps (prediction, label) pairs to predict_label.txt."""
        from sklearn.metrics import roc_curve, precision_recall_curve
        with open("predict_label.txt", 'w') as fout:
            for i, j in zip(predict_all, label_all):
                fout.write(str(i)[:4] + '\t' + str(j)[:4]+'\n')
        auc_score = roc_auc_score(label_all, predict_all)
        # curve data computed on the continuous scores, before binarization
        fpr, tpr, thresholds = roc_curve(label_all, predict_all, pos_label=1)
        precision, recall, thresholds = precision_recall_curve(label_all, predict_all)
        label_all = [int(i) for i in label_all]
        float2binary = lambda x: 0 if x < threshold else 1
        predict_all = list(map(float2binary, predict_all))
        f1score = f1_score(label_all, predict_all)
        prauc_score = average_precision_score(label_all, predict_all)
        precision = precision_score(label_all, predict_all)
        recall = recall_score(label_all, predict_all)
        accuracy = accuracy_score(label_all, predict_all)
        predict_1_ratio = sum(predict_all) / len(predict_all)
        label_1_ratio = sum(label_all) / len(label_all)
        return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio 

    def generate_predict(self, dataloader):
        """Run the model over dataloader.

        Returns (summed BCE loss, predicted probabilities, true labels)."""
        whole_loss = 0 
        label_all, predict_all = [], []
        for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader:
            X, _ = dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, self.global_icd) 
            output = self.forward(X).view(-1)  
            loss = self.loss(output, label_vec.float())
            whole_loss += loss.item()
            predict_all.extend([i.item() for i in torch.sigmoid(output)])  # logits -> probabilities
            label_all.extend([i.item() for i in label_vec])
        return whole_loss, predict_all, label_all

    def bootstrap_test(self, dataloader, validloader = None, sample_num = 20):
        """Evaluate sample_num bootstrap resamples of dataloader and print
        mean/std of PR-AUC, F1 and ROC-AUC; also saves a score histogram."""
        best_threshold = 0.5
        self.eval()
        whole_loss, predict_all, label_all = self.generate_predict(dataloader)
        from HINT.utils import plot_hist
        plt.clf()
        prefix_name = "./figure/" + self.save_name 
        plot_hist(prefix_name, predict_all, label_all)      
        def bootstrap(length, sample_num):
            # sample_num index lists, each drawn with replacement
            idx = [i for i in range(length)]
            from random import choices 
            bootstrap_idx = [choices(idx, k = length) for i in range(sample_num)]
            return bootstrap_idx 
        results_lst = []
        bootstrap_idx_lst = bootstrap(len(predict_all), sample_num = sample_num)
        for bootstrap_idx in bootstrap_idx_lst: 
            bootstrap_label = [label_all[idx] for idx in bootstrap_idx]     
            bootstrap_predict = [predict_all[idx] for idx in bootstrap_idx]
            results = self.evaluation(bootstrap_predict, bootstrap_label, threshold = best_threshold)
            results_lst.append(results)
        self.train() 
        auc = [results[0] for results in results_lst]
        f1score = [results[1] for results in results_lst]
        prauc_score = [results[2] for results in results_lst]
        print("PR-AUC   mean: "+str(np.mean(prauc_score))[:6], "std: "+str(np.std(prauc_score))[:6])
        print("F1       mean: "+str(np.mean(f1score))[:6], "std: "+str(np.std(f1score))[:6])
        print("ROC-AUC  mean: "+ str(np.mean(auc))[:6], "std: " + str(np.std(auc))[:6])

    def test(self, dataloader, return_loss = True, validloader=None):
        """If return_loss, return the summed BCE loss over dataloader;
        otherwise print and return the full metric tuple from evaluation."""
        self.eval()
        best_threshold = 0.5 
        whole_loss, predict_all, label_all = self.generate_predict(dataloader)
        self.train()
        if return_loss:
            return whole_loss
        else:
            print_num = 5
            auc_score, f1score, prauc_score, precision, recall, accuracy, \
            predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all, threshold = best_threshold)
            print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num] \
                 + "\nPR-AUC: " + str(prauc_score)[:print_num] \
                 + "\nPrecision: " + str(precision)[:print_num] \
                 + "\nrecall: "+str(recall)[:print_num] + "\naccuracy: "+str(accuracy)[:print_num] \
                 + "\npredict 1 ratio: " + str(predict_1_ratio)[:print_num] \
                 + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num])
            return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio 

    def plot_learning_curve(self, train_loss_record, valid_loss_record):
        """Save per-step training-loss and per-epoch validation-loss plots."""
        plt.plot(train_loss_record)
        plt.savefig("./figure/" + self.save_name + '_train_loss.jpg')
        plt.clf() 
        plt.plot(valid_loss_record)
        plt.savefig("./figure/" + self.save_name + '_valid_loss.jpg')
        plt.clf() 
852
class ADMET(nn.Sequential):
    """Multi-task ADMET predictor: one shared MPNN molecule encoder feeding
    five per-task heads (linear + Highway, then a scalar output layer)."""

    def __init__(self, mpnn_model, device):
        """
        mpnn_model -- molecule encoder exposing mpnn_hidden_size and
                      forward_smiles_lst_lst
        device     -- torch device every submodule is moved to
        """
        super(ADMET, self).__init__()
        self.num = 5  # number of ADMET tasks / prediction heads
        self.mpnn_model = mpnn_model 
        self.device = device 
        self.mpnn_dim = mpnn_model.mpnn_hidden_size 
        self.global_embed_size = self.mpnn_dim 
        self.highway_num_layer = 2 
        # consistency fix: both head-building loops hard-coded the literal 5
        # even though self.num already names it; derive them from self.num.
        admet_model = []
        for i in range(self.num):
            admet_fc = nn.Linear(self.mpnn_model.mpnn_hidden_size, self.global_embed_size).to(device)
            admet_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
            admet_model.append(nn.ModuleList([admet_fc, admet_highway]))
        self.admet_model = nn.ModuleList(admet_model)

        self.admet_pred = nn.ModuleList([nn.Linear(self.global_embed_size, 1).to(device) for i in range(self.num)])
        self.f = F.relu  # shared activation used by feed_lst_of_module

        # .to() is in-place for nn.Module; the original `self = self.to(device)`
        # rebinding was a no-op beyond the move itself.
        self.to(device)
874
    def feed_lst_of_module(self, input_feature, lst_of_module):
875
        x = input_feature
876
        for single_module in lst_of_module:
877
            x = self.f(single_module(x))
878
        return x 
879
880
    def forward(self, smiles_lst, idx):
881
        assert idx in list(range(5))
882
        '''
883
            xxxxxxxxxxxx
884
        '''
885
        embeds = self.mpnn_model.forward_smiles_lst_lst(smiles_lst)
886
        embeds = self.feed_lst_of_module(embeds, self.admet_model[idx]) 
887
        output = self.admet_pred[idx](embeds)
888
        return output 
889
890
    def test(self, valid_loader):
891
        pass 
892
893
    def learn(self, train_loader, valid_loader, idx):
894
        opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
895
        train_loss_record = [] 
896
        valid_loss = self.test(valid_loader, return_loss=True)
897
        valid_loss_record = [valid_loss]
898
        best_valid_loss = valid_loss
899
        best_model = deepcopy(self)
900
901
        for ep in tqdm(range(self.epoch)):
902
            for smiles_lst in train_loader:
903
                output = self.forward(smiles_lst).view(-1)  #### 32, 1 -> 32, ||  label_vec 32,
904
                loss = self.loss(output, label_vec.float())
905
                train_loss_record.append(loss.item())
906
                opt.zero_grad() 
907
                loss.backward() 
908
                opt.step()
909
            valid_loss = self.test(valid_loader, return_loss=True)
910
            valid_loss_record.append(valid_loss)
911
            if valid_loss < best_valid_loss:
912
                best_valid_loss = valid_loss 
913
                best_model = deepcopy(self)
914
915
        self = deepcopy(best_model)