--- a +++ b/HINT/model.py @@ -0,0 +1,915 @@ +from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, precision_score, recall_score, accuracy_score +import matplotlib.pyplot as plt +from copy import deepcopy +import numpy as np +from tqdm import tqdm +import torch +torch.manual_seed(0) +from torch import nn +from torch.autograd import Variable +import torch.nn.functional as F +from HINT.module import Highway, GCN +from functools import reduce +import pickle + + +class Interaction(nn.Sequential): + def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, + device, + global_embed_size, + highway_num_layer, + prefix_name, + epoch = 20, + lr = 3e-4, + weight_decay = 0, + ): + super(Interaction, self).__init__() + self.molecule_encoder = molecule_encoder + self.disease_encoder = disease_encoder + self.protocol_encoder = protocol_encoder + self.global_embed_size = global_embed_size + self.highway_num_layer = highway_num_layer + self.feature_dim = self.molecule_encoder.embedding_size + self.disease_encoder.embedding_size + self.protocol_encoder.embedding_size + self.epoch = epoch + self.lr = lr + self.weight_decay = weight_decay + self.save_name = prefix_name + '_interaction' + + self.f = F.relu + self.loss = nn.BCEWithLogitsLoss() + + ##### NN + self.encoder2interaction_fc = nn.Linear(self.feature_dim, self.global_embed_size).to(device) + self.encoder2interaction_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device) + self.pred_nn = nn.Linear(self.global_embed_size, 1) + + self.device = device + self = self.to(device) + + def feed_lst_of_module(self, input_feature, lst_of_module): + x = input_feature + for single_module in lst_of_module: + x = self.f(single_module(x)) + return x + + def forward_get_three_encoders(self, smiles_lst2, icdcode_lst3, criteria_lst): + molecule_embed = self.molecule_encoder.forward_smiles_lst_lst(smiles_lst2) + icd_embed = self.disease_encoder.forward_code_lst3(icdcode_lst3) + 
protocol_embed = self.protocol_encoder.forward(criteria_lst) + return molecule_embed, icd_embed, protocol_embed + + def forward_encoder_2_interaction(self, molecule_embed, icd_embed, protocol_embed): + encoder_embedding = torch.cat([molecule_embed, icd_embed, protocol_embed], 1) + # interaction_embedding = self.feed_lst_of_module(encoder_embedding, [self.encoder2interaction_fc, self.encoder2interaction_highway]) + h = self.encoder2interaction_fc(encoder_embedding) + h = self.f(h) + h = self.encoder2interaction_highway(h) + interaction_embedding = self.f(h) + return interaction_embedding + + def forward(self, smiles_lst2, icdcode_lst3, criteria_lst): + molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst) + interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed) + output = self.pred_nn(interaction_embedding) + return output ### 32, 1 + + def evaluation(self, predict_all, label_all, threshold = 0.5): + import pickle, os + from sklearn.metrics import roc_curve, precision_recall_curve + with open("predict_label.txt", 'w') as fout: + for i,j in zip(predict_all, label_all): + fout.write(str(i)[:6] + '\t' + str(j)[:4]+'\n') + auc_score = roc_auc_score(label_all, predict_all) + figure_folder = "figure" + #### ROC-curve + fpr, tpr, thresholds = roc_curve(label_all, predict_all, pos_label=1) + # roc_curve =plt.figure() + # plt.plot(fpr,tpr,'-',label=self.save_name + ' ROC Curve ') + # plt.legend(fontsize = 15) + # plt.savefig(os.path.join(figure_folder,self.save_name+"_roc_curve.png")) + #### PR-curve + precision, recall, thresholds = precision_recall_curve(label_all, predict_all) + # plt.plot(recall,precision, label = self.save_name + ' PR Curve') + # plt.legend(fontsize = 15) + # plt.savefig(os.path.join(figure_folder,self.save_name + "_pr_curve.png")) + label_all = [int(i) for i in label_all] + float2binary = lambda x:0 if x < threshold else 1 + predict_all = 
list(map(float2binary, predict_all)) + f1score = f1_score(label_all, predict_all) + prauc_score = average_precision_score(label_all, predict_all) + # print(predict_all) + precision = precision_score(label_all, predict_all) + recall = recall_score(label_all, predict_all) + accuracy = accuracy_score(label_all, predict_all) + predict_1_ratio = sum(predict_all) / len(predict_all) + label_1_ratio = sum(label_all) / len(label_all) + return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio + + def testloader_to_lst(self, dataloader): + nctid_lst, label_lst, smiles_lst2, icdcode_lst3, criteria_lst = [], [], [], [], [] + for nctid, label, smiles, icdcode, criteria in dataloader: + nctid_lst.extend(nctid) + label_lst.extend([i.item() for i in label]) + smiles_lst2.extend(smiles) + icdcode_lst3.extend(icdcode) + criteria_lst.extend(criteria) + length = len(nctid_lst) + assert length == len(smiles_lst2) and length == len(icdcode_lst3) + return nctid_lst, label_lst, smiles_lst2, icdcode_lst3, criteria_lst, length + + def generate_predict(self, dataloader): + whole_loss = 0 + label_all, predict_all, nctid_all = [], [], [] + for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader: + nctid_all.extend(nctid_lst) + label_vec = label_vec.to(self.device) + output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst).view(-1) + loss = self.loss(output, label_vec.float()) + whole_loss += loss.item() + predict_all.extend([i.item() for i in torch.sigmoid(output)]) + label_all.extend([i.item() for i in label_vec]) + + return whole_loss, predict_all, label_all, nctid_all + + def bootstrap_test(self, dataloader, valid_loader = None, sample_num = 20): + best_threshold = 0.5 + # if validloader is not None: + # best_threshold = self.select_threshold_for_binary(valid_loader) + # print(f"best_threshold: {best_threshold}") + self.eval() + whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader) + from 
HINT.utils import plot_hist + plt.clf() + prefix_name = "./figure/" + self.save_name + plot_hist(prefix_name, predict_all, label_all) + def bootstrap(length, sample_num): + idx = [i for i in range(length)] + from random import choices + bootstrap_idx = [choices(idx, k = length) for i in range(sample_num)] + return bootstrap_idx + results_lst = [] + bootstrap_idx_lst = bootstrap(len(predict_all), sample_num = sample_num) + for bootstrap_idx in bootstrap_idx_lst: + bootstrap_label = [label_all[idx] for idx in bootstrap_idx] + bootstrap_predict = [predict_all[idx] for idx in bootstrap_idx] + results = self.evaluation(bootstrap_predict, bootstrap_label, threshold = best_threshold) + results_lst.append(results) + self.train() + auc = [results[0] for results in results_lst] + f1score = [results[1] for results in results_lst] + prauc_score = [results[2] for results in results_lst] + print("PR-AUC mean: "+str(np.mean(prauc_score))[:6], "std: "+str(np.std(prauc_score))[:6]) + print("F1 mean: "+str(np.mean(f1score))[:6], "std: "+str(np.std(f1score))[:6]) + print("ROC-AUC mean: "+str(np.mean(auc))[:6], "std: "+str(np.std(auc))[:6]) + + for nctid, label, predict in zip(nctid_all, label_all, predict_all): + if (predict > 0.5 and label == 0) or (predict < 0.5 and label == 1): + print(nctid, label, str(predict)[:6]) + + nctid2predict = {nctid:predict for nctid, predict in zip(nctid_all, predict_all)} + pickle.dump(nctid2predict, open('results/nctid2predict.pkl', 'wb')) + return nctid_all, predict_all + + def ongoing_test(self, dataloader, sample_num = 20): + self.eval() + best_threshold = 0.5 + whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader) + self.train() + return nctid_all, predict_all + + def test(self, dataloader, return_loss = True, validloader=None): + # if validloader is not None: + # best_threshold = self.select_threshold_for_binary(validloader) + self.eval() + best_threshold = 0.5 + whole_loss, predict_all, label_all, nctid_all = 
self.generate_predict(dataloader) + # from HINT.utils import plot_hist + # plt.clf() + # prefix_name = "./figure/" + self.save_name + # plot_hist(prefix_name, predict_all, label_all) + self.train() + if return_loss: + return whole_loss, predict_all, label_all + else: + print_num = 6 + auc_score, f1score, prauc_score, precision, recall, accuracy, \ + predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all, threshold = best_threshold) + print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num] \ + + "\nPR-AUC: " + str(prauc_score)[:print_num] \ + + "\nPrecision: " + str(precision)[:print_num] \ + + "\nrecall: "+str(recall)[:print_num] + "\naccuracy: "+str(accuracy)[:print_num] \ + + "\npredict 1 ratio: " + str(predict_1_ratio)[:print_num] \ + + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num]) + return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio + + def learn(self, train_loader, valid_loader, test_loader): + opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay) + train_loss_record = [] + valid_loss, valid_predict, valid_label = self.test(valid_loader, return_loss=True) + valid_loss_record = [valid_loss] + best_valid_loss = valid_loss + best_model = deepcopy(self) + train_output = [] + valid_output = [] + for ep in tqdm(range(self.epoch)): + for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader: + label_vec = label_vec.to(self.device) + output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst).view(-1) #### 32, 1 -> 32, || label_vec 32, + loss = self.loss(output, label_vec.float()) + train_loss_record.append(loss.item()) + train_output.append((loss.item(), output, label_vec)) + opt.zero_grad() + loss.backward() + opt.step() + valid_loss, valid_predict, valid_label = self.test(valid_loader, return_loss=True) + valid_loss_record.append(valid_loss) + valid_output.append((valid_loss, valid_predict, 
valid_label)) + + print(f"valid_loss: {valid_loss}") + print(best_valid_loss) + if valid_loss < best_valid_loss: + best_valid_loss = valid_loss + best_model = deepcopy(self) + + self.plot_learning_curve(train_loss_record, valid_loss_record) + self = deepcopy(best_model) + auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader) + return train_output, valid_output + + def plot_learning_curve(self, train_loss_record, valid_loss_record): + plt.plot(train_loss_record) + plt.savefig("./figure/" + self.save_name + '_train_loss.jpg') + plt.clf() + plt.plot(valid_loss_record) + plt.savefig("./figure/" + self.save_name + '_valid_loss.jpg') + plt.clf() + + def select_threshold_for_binary(self, validloader): + _, prediction, label_all, nctid_all = self.generate_predict(validloader) + best_f1 = 0 + for threshold in prediction: + float2binary = lambda x:0 if x<threshold else 1 + predict_all = list(map(float2binary, prediction)) + f1score = precision_score(label_all, predict_all) + if f1score > best_f1: + best_f1 = f1score + best_threshold = threshold + return best_threshold + + +class HINTModel_multi(Interaction): + + def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, + device, + global_embed_size, + highway_num_layer, + prefix_name, + epoch = 20, + lr = 3e-4, + weight_decay = 0, + ): + super(HINTModel_multi, self).__init__(molecule_encoder = molecule_encoder, + disease_encoder = disease_encoder, + protocol_encoder = protocol_encoder, + device = device, + prefix_name = prefix_name, + global_embed_size = global_embed_size, + highway_num_layer = highway_num_layer, + epoch = epoch, + lr = lr, + weight_decay = weight_decay) + self.pred_nn = nn.Linear(self.global_embed_size, 4) + self.loss = nn.CrossEntropyLoss() + + def forward(self, smiles_lst2, icdcode_lst3, criteria_lst): + molecule_embed, icd_embed, protocol_embed = 
self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst) + interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed) + output = self.pred_nn(interaction_embedding) + return output ### 32, 4 + + def generate_predict(self, dataloader): + whole_loss = 0 + label_all, predict_all = [], [] + for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader: + label_vec = label_vec.to(self.device) + output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst) + loss = self.loss(output, label_vec) + whole_loss += loss.item() + predict_all.extend(torch.argmax(output, 1).tolist()) + # predict_all.extend([i.item() for i in torch.sigmoid(output)]) + label_all.extend([i.item() for i in label_vec]) + + accuracy = len(list(filter(lambda x:x[0]==x[1], zip(predict_all, label_all)))) / len(label_all) + return whole_loss, predict_all, label_all, accuracy + + def test(self, dataloader, return_loss = True, validloader=None): + # if validloader is not None: + # best_threshold = self.select_threshold_for_binary(validloader) + self.eval() + whole_loss, predict_all, label_all, accuracy = self.generate_predict(dataloader) + self.train() + return whole_loss, predict_all, label_all, accuracy + # # from HINT.utils import plot_hist + # # plt.clf() + # # prefix_name = "./figure/" + self.save_name + # # plot_hist(prefix_name, predict_all, label_all) + # self.train() + # if return_loss: + # return whole_loss + # else: + # print_num = 5 + # auc_score, f1score, prauc_score, precision, recall, accuracy, \ + # predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all, threshold = best_threshold) + # print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num] \ + # + "\nPR-AUC: " + str(prauc_score)[:print_num] \ + # + "\nPrecision: " + str(precision)[:print_num] \ + # + "\nrecall: "+str(recall)[:print_num] + "\naccuracy: "+str(accuracy)[:print_num] \ + # + "\npredict 1 ratio: " + 
str(predict_1_ratio)[:print_num] \ + # + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num]) + # return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio + + def learn(self, train_loader, valid_loader, test_loader): + opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay) + train_loss_record = [] + valid_loss, predict_all, label_all, accuracy = self.test(valid_loader, return_loss=True) + print('accuracy', accuracy) + # valid_loss_record = [valid_loss] + # best_valid_loss = valid_loss + best_model = deepcopy(self) + for ep in tqdm(range(self.epoch)): + self.train() + for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader: + label_vec = label_vec.to(self.device) + output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst) #### 32, 1 -> 32, || label_vec 32, + # print(label_vec.shape, output.shape, label_vec, output) + loss = self.loss(output, label_vec) + train_loss_record.append(loss.item()) + opt.zero_grad() + loss.backward() + opt.step() + valid_loss, predict_all, label_all, accuracy = self.test(valid_loader, return_loss=True) + print('accuracy', accuracy) + return predict_all, label_all + # valid_loss_record.append(valid_loss) + # if valid_loss < best_valid_loss: + # best_valid_loss = valid_loss + # best_model = deepcopy(self) + + # self.plot_learning_curve(train_loss_record, valid_loss_record) + # self = deepcopy(best_model) + # auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader) + + +class HINT_nograph(Interaction): + def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, device, + global_embed_size, + highway_num_layer, + prefix_name, + epoch = 20, + lr = 3e-4, + weight_decay = 0, ): + super(HINT_nograph, self).__init__(molecule_encoder = molecule_encoder, + disease_encoder = disease_encoder, + protocol_encoder = 
protocol_encoder, + device = device, + global_embed_size = global_embed_size, + prefix_name = prefix_name, + highway_num_layer = highway_num_layer, + epoch = epoch, + lr = lr, + weight_decay = weight_decay, + ) + self.save_name = prefix_name + '_HINT_nograph' + ''' ### interaction model + self.molecule_encoder = molecule_encoder + self.disease_encoder = disease_encoder + self.protocol_encoder = protocol_encoder + self.global_embed_size = global_embed_size + self.highway_num_layer = highway_num_layer + self.feature_dim = self.molecule_encoder.embedding_size + self.disease_encoder.embedding_size + self.protocol_encoder.embedding_size + self.epoch = epoch + self.lr = lr + self.weight_decay = weight_decay + self.save_name = save_name + + self.f = F.relu + self.loss = nn.BCEWithLogitsLoss() + + ##### NN + self.encoder2interaction_fc = nn.Linear(self.feature_dim, self.global_embed_size) + self.encoder2interaction_highway = Highway(self.global_embed_size, self.highway_num_layer) + self.pred_nn = nn.Linear(self.global_embed_size, 1) + ''' + + #### risk of disease + self.risk_disease_fc = nn.Linear(self.disease_encoder.embedding_size, self.global_embed_size) + self.risk_disease_higway = Highway(self.global_embed_size, self.highway_num_layer) + + #### augment interaction + self.augment_interaction_fc = nn.Linear(self.global_embed_size*2, self.global_embed_size) + self.augment_interaction_highway = Highway(self.global_embed_size, self.highway_num_layer) + + #### ADMET + self.admet_model = [] + for i in range(5): + admet_fc = nn.Linear(self.molecule_encoder.embedding_size, self.global_embed_size).to(device) + admet_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device) + self.admet_model.append(nn.ModuleList([admet_fc, admet_highway])) + self.admet_model = nn.ModuleList(self.admet_model) + + #### PK + self.pk_fc = nn.Linear(self.global_embed_size*5, self.global_embed_size) + self.pk_highway = Highway(self.global_embed_size, self.highway_num_layer) + + 
#### trial node + self.trial_fc = nn.Linear(self.global_embed_size*2, self.global_embed_size) + self.trial_highway = Highway(self.global_embed_size, self.highway_num_layer) + + ## self.pred_nn = nn.Linear(self.global_embed_size, 1) + + self.device = device + self = self.to(device) + + def forward(self, smiles_lst2, icdcode_lst3, criteria_lst, if_gnn = False): + ### encoder for molecule, disease and protocol + molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst) + ### interaction + interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed) + ### risk of disease + risk_of_disease_embedding = self.feed_lst_of_module(input_feature = icd_embed, + lst_of_module = [self.risk_disease_fc, self.risk_disease_higway]) + ### augment interaction + augment_interaction_input = torch.cat([interaction_embedding, risk_of_disease_embedding], 1) + augment_interaction_embedding = self.feed_lst_of_module(input_feature = augment_interaction_input, + lst_of_module = [self.augment_interaction_fc, self.augment_interaction_highway]) + ### admet + admet_embedding_lst = [] + for idx in range(5): + admet_embedding = self.feed_lst_of_module(input_feature = molecule_embed, + lst_of_module = self.admet_model[idx]) + admet_embedding_lst.append(admet_embedding) + ### pk + pk_input = torch.cat(admet_embedding_lst, 1) + pk_embedding = self.feed_lst_of_module(input_feature = pk_input, + lst_of_module = [self.pk_fc, self.pk_highway]) + ### trial + trial_input = torch.cat([pk_embedding, augment_interaction_embedding], 1) + trial_embedding = self.feed_lst_of_module(input_feature = trial_input, + lst_of_module = [self.trial_fc, self.trial_highway]) + output = self.pred_nn(trial_embedding) + if if_gnn == False: + return output + else: + embedding_lst = [molecule_embed, icd_embed, protocol_embed, interaction_embedding, risk_of_disease_embedding, \ + augment_interaction_embedding] + admet_embedding_lst + 
[pk_embedding, trial_embedding] + return embedding_lst + + +class HINTModel(HINT_nograph): + + def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, + device, + global_embed_size, + highway_num_layer, + prefix_name, + gnn_hidden_size, + epoch = 20, + lr = 3e-4, + weight_decay = 0,): + super(HINTModel, self).__init__(molecule_encoder = molecule_encoder, + disease_encoder = disease_encoder, + protocol_encoder = protocol_encoder, + device = device, + prefix_name = prefix_name, + global_embed_size = global_embed_size, + highway_num_layer = highway_num_layer, + epoch = epoch, + lr = lr, + weight_decay = weight_decay) + self.save_name = prefix_name + self.gnn_hidden_size = gnn_hidden_size + #### GNN + self.adj = self.generate_adj() + self.gnn = GCN( + nfeat = self.global_embed_size, + nhid = self.gnn_hidden_size, + nclass = 1, + dropout = 0.6, + init = 'uniform') + ### gnn's attention + self.node_size = self.adj.shape[0] + ''' + self.graph_attention_model_mat = nn.ModuleList([nn.ModuleList([self.gnn_attention() \ + if self.adj[i,j]==1 else None \ + for j in range(self.node_size)]) \ + for i in range(self.node_size)]) + ''' + self.graph_attention_model_mat = nn.ModuleList([nn.ModuleList([self.gnn_attention() if self.adj[i,j]==1 else None for j in range(self.node_size)]) for i in range(self.node_size)]) + # self.graph_attention_model_mat = nn.ModuleList([nn.ModuleList([self.gnn_attention() if self.adj[i,j]==1 else None for j in range(self.node_size)]) for i in range(self.node_size)]) + + ''' +nn.ModuleList([ nn.ModuleList([nn.Linear(3,2) for j in range(5)] + [None]) for i in range(3)]) + ''' + + self.device = device + self = self.to(device) + + def generate_adj(self): + ##### consistent with HINT_nograph.forward + lst = ["molecule", "disease", "criteria", 'INTERACTION', 'risk_disease', 'augment_interaction', 'A', 'D', 'M', 'E', 'T', 'PK', "final"] + edge_lst = [("disease", "molecule"), ("disease", "criteria"), ("molecule", "criteria"), + ("disease", 
"INTERACTION"), ("molecule", "INTERACTION"), ("criteria", "INTERACTION"), + ("disease", "risk_disease"), ('risk_disease', 'augment_interaction'), ('INTERACTION', 'augment_interaction'), + ("molecule", "A"), ("molecule", "D"), ("molecule", "M"), ("molecule", "E"), ("molecule", "T"), + ('A', 'PK'), ('D', 'PK'), ('M', 'PK'), ('E', 'PK'), ('T', 'PK'), + ('augment_interaction', 'final'), ('PK', 'final')] + adj = torch.zeros(len(lst), len(lst)) + adj = torch.eye(len(lst)) * len(lst) + num2str = {k:v for k,v in enumerate(lst)} + str2num = {v:k for k,v in enumerate(lst)} + for i,j in edge_lst: + n1,n2 = str2num[i], str2num[j] + adj[n1,n2] = 1 + adj[n2,n1] = 1 + return adj.to(self.device) + + def generate_attention_matrx(self, node_feature_mat): + attention_mat = torch.zeros(self.node_size, self.node_size).to(self.device) + for i in range(self.node_size): + for j in range(self.node_size): + if self.adj[i,j]!=1: + continue + feature = torch.cat([node_feature_mat[i].view(1,-1), node_feature_mat[j].view(1,-1)], 1) + attention_model = self.graph_attention_model_mat[i][j] + attention_mat[i,j] = torch.sigmoid(self.feed_lst_of_module(input_feature=feature, lst_of_module=attention_model)) + return attention_mat + + ##### self.global_embed_size*2 -> 1 + def gnn_attention(self): + highway_nn = Highway(size = self.global_embed_size*2, num_layers = self.highway_num_layer).to(self.device) + highway_fc = nn.Linear(self.global_embed_size*2, 1).to(self.device) + return nn.ModuleList([highway_nn, highway_fc]) + + def forward(self, smiles_lst2, icdcode_lst3, criteria_lst, return_attention_matrix = False): + embedding_lst = HINT_nograph.forward(self, smiles_lst2, icdcode_lst3, criteria_lst, if_gnn = True) + ### length is 13, each is 32,50 + batch_size = embedding_lst[0].shape[0] + output_lst = [] + if return_attention_matrix: + attention_mat_lst = [] + for i in range(batch_size): + node_feature_lst = [embedding[i].view(1,-1) for embedding in embedding_lst] + node_feature_mat = 
torch.cat(node_feature_lst, 0) ### 13, 50 + attention_mat = self.generate_attention_matrx(node_feature_mat) + output = self.gnn(node_feature_mat, self.adj * attention_mat) + output = output[-1].view(1,-1) + output_lst.append(output) + if return_attention_matrix: + attention_mat_lst.append(attention_mat) + output_mat = torch.cat(output_lst, 0) + if not return_attention_matrix: + return output_mat + else: + return output_mat, attention_mat_lst + + def interpret(self, complete_dataloader): + from graph_visualize_interpret import data2graph + from HINT.utils import replace_strange_symbol + for nctid_lst, status_lst, why_stop_lst, label_vec, phase_lst, \ + diseases_lst, icdcode_lst3, drugs_lst, smiles_lst2, criteria_lst in complete_dataloader: + output, attention_mat_lst = self.forward(smiles_lst2, icdcode_lst3, criteria_lst, return_attention_matrix=True) + output = output.view(-1) + batch_size = len(nctid_lst) + for i in range(batch_size): + name = '__'.join([nctid_lst[i], status_lst[i], why_stop_lst[i], \ + str(label_vec[i].item()), str(torch.sigmoid(output[i]).item())[:5], \ + phase_lst[i], diseases_lst[i], drugs_lst[i]]) + if len(name) > 150: + name = name[:250] + name = replace_strange_symbol(name) + name = name.replace('__', '_') + name = name.replace(' ', ' ') + name = 'interpret_result/' + name + '.png' + print(name) + data2graph(attention_matrix = attention_mat_lst[i], adj = self.adj, save_name = name) + + def init_pretrain(self, admet_model): + self.molecule_encoder = admet_model.molecule_encoder + + ### generate attention matrix + + +class Only_Molecule(Interaction): + + def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, + global_embed_size, + highway_num_layer, + prefix_name, + epoch = 20, + lr = 3e-4, + weight_decay = 0): + super(Only_Molecule, self).__init__(molecule_encoder=molecule_encoder, + disease_encoder=disease_encoder, + protocol_encoder=protocol_encoder, + global_embed_size = global_embed_size, + highway_num_layer = 
highway_num_layer, + prefix_name = prefix_name, + epoch = epoch, + lr = lr, + weight_decay = weight_decay,) + self.molecule2out = nn.Linear(self.global_embed_size,1) + + + def forward(self, smiles_lst2, icdcode_lst3, criteria_lst): + molecule_embed = self.molecule_encoder.forward_smiles_lst_lst(smiles_lst2) + return self.molecule2out(molecule_embed) + +class Only_Disease(Only_Molecule): + + def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, + global_embed_size, + highway_num_layer, + prefix_name, + epoch = 20, + lr = 3e-4, + weight_decay = 0): + super(Only_Disease, self).__init__(molecule_encoder = molecule_encoder, + disease_encoder=disease_encoder, + protocol_encoder=protocol_encoder, + global_embed_size = global_embed_size, + highway_num_layer = highway_num_layer, + prefix_name = prefix_name, + epoch = epoch, + lr = lr, + weight_decay = weight_decay,) + self.disease2out = self.molecule2out + + + def forward(self, smiles_lst2, icdcode_lst3, criteria_lst): + icd_embed = self.disease_encoder.forward_code_lst3(icdcode_lst3) + return self.disease2out(icd_embed) + +def dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, global_icd): + ## label_vec: (n,) + y = label_vec + + num_icd = len(global_icd) + from HINT.utils import smiles_lst2fp + fp_lst = [smiles_lst2fp(smiles_lst).reshape(1,-1) for smiles_lst in smiles_lst2] + fp_mat = np.concatenate(fp_lst, 0) + # fp_mat = torch.from_numpy(fp_mat) ### (n,2048) + + icdcode_lst = [] + for lst2 in icdcode_lst3: + lst = list(reduce(lambda x,y:x+y, lst2)) + lst = [i.split('.')[0] for i in lst] + lst = set(lst) + icd_feature = np.zeros((1,num_icd), np.int32) + for ele in lst: + if ele in global_icd: + idx = global_icd.index(ele) + icd_feature[0,idx] = 1 + icdcode_lst.append(icd_feature) + icdcode_mat = np.concatenate(icdcode_lst, 0) + X = np.concatenate([fp_mat, icdcode_mat], 1) + X = torch.from_numpy(X) + X = X.float() + # icdcode_mat = torch.from_numpy(icdcode_mat) + + # X = 
torch.cat([fp_mat, icdcode_mat], 1) + return X, y + + +class FFNN(nn.Sequential): + def __init__(self, molecule_dim, diseasecode_dim, + global_icd, + protocol_dim = 0, + prefix_name = 'FFNN', + epoch = 10, + lr = 3e-4, + weight_decay = 0, + ): + super(FFNN, self).__init__() + self.molecule_dim = molecule_dim + self.diseasecode_dim = diseasecode_dim + self.protocol_dim = protocol_dim + self.prefix_name = prefix_name + self.epoch = epoch + self.lr = lr + self.weight_decay = weight_decay + self.global_icd = global_icd + self.num_icd = len(global_icd) + + self.fc_dims = [self.molecule_dim + self.diseasecode_dim + self.protocol_dim, 2000, 1000, 200, 50, 1] + self.fc_layers = nn.ModuleList([nn.Linear(v,self.fc_dims[i+1]) for i,v in enumerate(self.fc_dims[:-1])]) + self.loss = nn.BCEWithLogitsLoss() + self.save_name = prefix_name + + def forward(self, X): + for i in range(len(self.fc_layers) - 1): + fc_layer = self.fc_layers[i] + X = fc_layer(X) + last_layer = self.fc_layers[-1] + pred = F.sigmoid(last_layer(X)) + return pred + + def learn(self, train_loader, valid_loader, test_loader): + opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay) + train_loss_record = [] + valid_loss = self.test(valid_loader, return_loss=True) + valid_loss_record = [valid_loss] + best_valid_loss = valid_loss + best_model = deepcopy(self) + + for ep in tqdm(range(self.epoch)): + for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader: + X, _ = dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, self.global_icd) + output = self.forward(X).view(-1) #### 32, 1 -> 32, || label_vec 32, + loss = self.loss(output, label_vec.float()) + train_loss_record.append(loss.item()) + opt.zero_grad() + loss.backward() + opt.step() + valid_loss = self.test(valid_loader, return_loss=True) + valid_loss_record.append(valid_loss) + if valid_loss < best_valid_loss: + best_valid_loss = valid_loss + best_model = deepcopy(self) + + 
self.plot_learning_curve(train_loss_record, valid_loss_record) + self = deepcopy(best_model) + auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader) + + def evaluation(self, predict_all, label_all, threshold = 0.5): + import pickle, os + from sklearn.metrics import roc_curve, precision_recall_curve + with open("predict_label.txt", 'w') as fout: + for i,j in zip(predict_all, label_all): + fout.write(str(i)[:4] + '\t' + str(j)[:4]+'\n') + auc_score = roc_auc_score(label_all, predict_all) + figure_folder = "figure" + #### ROC-curve + fpr, tpr, thresholds = roc_curve(label_all, predict_all, pos_label=1) + # roc_curve =plt.figure() + # plt.plot(fpr,tpr,'-',label=self.save_name + ' ROC Curve ') + # plt.legend(fontsize = 15) + #plt.savefig(os.path.join(figure_folder,name+"_roc_curve.png")) + #### PR-curve + precision, recall, thresholds = precision_recall_curve(label_all, predict_all) + # plt.plot(recall,precision, label = self.save_name + ' PR Curve') + # plt.legend(fontsize = 15) + # plt.savefig(os.path.join(figure_folder,self.save_name + "_pr_curve.png")) + label_all = [int(i) for i in label_all] + float2binary = lambda x:0 if x<threshold else 1 + predict_all = list(map(float2binary, predict_all)) + f1score = f1_score(label_all, predict_all) + prauc_score = average_precision_score(label_all, predict_all) + # print(predict_all) + precision = precision_score(label_all, predict_all) + recall = recall_score(label_all, predict_all) + accuracy = accuracy_score(label_all, predict_all) + predict_1_ratio = sum(predict_all) / len(predict_all) + label_1_ratio = sum(label_all) / len(label_all) + return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio + + def generate_predict(self, dataloader): + whole_loss = 0 + label_all, predict_all = [], [] + for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader: + 
X, _ = dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, self.global_icd) + output = self.forward(X).view(-1) + loss = self.loss(output, label_vec.float()) + whole_loss += loss.item() + predict_all.extend([i.item() for i in torch.sigmoid(output)]) + label_all.extend([i.item() for i in label_vec]) + + return whole_loss, predict_all, label_all + + def bootstrap_test(self, dataloader, validloader = None, sample_num = 20): + best_threshold = 0.5 + # if validloader is not None: + # best_threshold = self.select_threshold_for_binary(validloader) + self.eval() + whole_loss, predict_all, label_all = self.generate_predict(dataloader) + from HINT.utils import plot_hist + plt.clf() + prefix_name = "./figure/" + self.save_name + plot_hist(prefix_name, predict_all, label_all) + def bootstrap(length, sample_num): + idx = [i for i in range(length)] + from random import choices + bootstrap_idx = [choices(idx, k = length) for i in range(sample_num)] + return bootstrap_idx + results_lst = [] + bootstrap_idx_lst = bootstrap(len(predict_all), sample_num = sample_num) + for bootstrap_idx in bootstrap_idx_lst: + bootstrap_label = [label_all[idx] for idx in bootstrap_idx] + bootstrap_predict = [predict_all[idx] for idx in bootstrap_idx] + results = self.evaluation(bootstrap_predict, bootstrap_label, threshold = best_threshold) + results_lst.append(results) + self.train() + auc = [results[0] for results in results_lst] + f1score = [results[1] for results in results_lst] + prauc_score = [results[2] for results in results_lst] + print("PR-AUC mean: "+str(np.mean(prauc_score))[:6], "std: "+str(np.std(prauc_score))[:6]) + print("F1 mean: "+str(np.mean(f1score))[:6], "std: "+str(np.std(f1score))[:6]) + print("ROC-AUC mean: "+ str(np.mean(auc))[:6], "std: " + str(np.std(auc))[:6]) + + def test(self, dataloader, return_loss = True, validloader=None): + # if validloader is not None: + # best_threshold = self.select_threshold_for_binary(validloader) + self.eval() + 
def plot_learning_curve(self, train_loss_record, valid_loss_record):
    """Save the training and validation loss curves as JPEG images.

    Writes ``./figure/<save_name>_train_loss.jpg`` and
    ``./figure/<save_name>_valid_loss.jpg``, clearing the current
    matplotlib figure after each save.

    Args:
        train_loss_record: sequence of training loss values to plot.
        valid_loss_record: sequence of validation loss values to plot.
    """
    import os

    figure_dir = "./figure/"
    # savefig does not create missing directories; without this the call
    # fails when the output directory is absent.
    os.makedirs(figure_dir, exist_ok=True)

    curves = (
        (train_loss_record, '_train_loss.jpg'),
        (valid_loss_record, '_valid_loss.jpg'),
    )
    for loss_record, suffix in curves:
        plt.plot(loss_record)
        plt.savefig(figure_dir + self.save_name + suffix)
        plt.clf()
def forward(self, smiles_lst, idx):
    """Run the ``idx``-th prediction head on a batch of molecules.

    Args:
        smiles_lst: molecule input, handed unchanged to
            ``self.mpnn_model.forward_smiles_lst_lst``.
        idx: index of the head to use; selects both
            ``self.admet_model[idx]`` and ``self.admet_pred[idx]``.

    Returns:
        The output of ``self.admet_pred[idx]`` applied to the embedding
        produced by feeding the molecule embedding through
        ``self.admet_model[idx]``.

    Raises:
        AssertionError: if ``idx`` is not a valid head index.
    """
    num_heads = len(self.admet_pred)
    # Raised explicitly instead of via `assert` so the check survives
    # `python -O`; the exception type is kept for backward compatibility.
    if idx not in range(num_heads):
        raise AssertionError(
            "idx must be an integer in [0, %d), got %r" % (num_heads, idx)
        )
    embeds = self.mpnn_model.forward_smiles_lst_lst(smiles_lst)
    embeds = self.feed_lst_of_module(embeds, self.admet_model[idx])
    return self.admet_pred[idx](embeds)