Diff of /HINT/model.py [000000] .. [bc9e98]

Switch to side-by-side view

--- a
+++ b/HINT/model.py
@@ -0,0 +1,915 @@
+from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, precision_score, recall_score, accuracy_score
+import matplotlib.pyplot as plt
+from copy import deepcopy 
+import numpy as np 
+from tqdm import tqdm 
+import torch 
+torch.manual_seed(0)
+from torch import nn 
+from torch.autograd import Variable
+import torch.nn.functional as F
+from HINT.module import Highway, GCN 
+from functools import reduce 
+import pickle
+
+
class Interaction(nn.Sequential):
	"""Base clinical-trial outcome model.

	Encodes three trial views (drug SMILES, disease ICD codes, eligibility
	criteria text), fuses them into a single "interaction" embedding and
	predicts one success logit per trial. Subclasses reuse the training and
	evaluation loops defined here.
	"""
	def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, 
					device, 
					global_embed_size, 
					highway_num_layer,
					prefix_name, 
					epoch = 20,
					lr = 3e-4, 
					weight_decay = 0, 
					):
		"""
		Args:
			molecule_encoder / disease_encoder / protocol_encoder: encoders that
				expose ``embedding_size`` and the forward helpers used below.
			device: torch device the whole module is moved onto.
			global_embed_size: width of the fused interaction embedding.
			highway_num_layer: number of layers per Highway block.
			prefix_name: prefix for checkpoint/figure file names.
			epoch, lr, weight_decay: training hyper-parameters for ``learn``.
		"""
		super(Interaction, self).__init__()
		self.molecule_encoder = molecule_encoder 
		self.disease_encoder = disease_encoder 
		self.protocol_encoder = protocol_encoder 
		self.global_embed_size = global_embed_size 
		self.highway_num_layer = highway_num_layer 
		# Width of the concatenation of the three encoder outputs.
		self.feature_dim = self.molecule_encoder.embedding_size + self.disease_encoder.embedding_size + self.protocol_encoder.embedding_size
		self.epoch = epoch 
		self.lr = lr 
		self.weight_decay = weight_decay 
		self.save_name = prefix_name + '_interaction'

		self.f = F.relu
		# BCEWithLogitsLoss applies sigmoid internally, so forward() returns raw logits.
		self.loss = nn.BCEWithLogitsLoss()

		##### NN 
		self.encoder2interaction_fc = nn.Linear(self.feature_dim, self.global_embed_size).to(device)
		self.encoder2interaction_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
		self.pred_nn = nn.Linear(self.global_embed_size, 1)

		self.device = device 
		# nn.Module.to() moves parameters in place (and returns self), so this
		# also covers pred_nn, which was not explicitly .to(device)'d above.
		self = self.to(device)

	def feed_lst_of_module(self, input_feature, lst_of_module):
		"""Apply each module in order, with self.f (ReLU) after every one."""
		x = input_feature
		for single_module in lst_of_module:
			x = self.f(single_module(x))
		return x

	def forward_get_three_encoders(self, smiles_lst2, icdcode_lst3, criteria_lst):
		"""Run the three encoders; returns (molecule, disease, protocol) embeddings."""
		molecule_embed = self.molecule_encoder.forward_smiles_lst_lst(smiles_lst2)
		icd_embed = self.disease_encoder.forward_code_lst3(icdcode_lst3)
		protocol_embed = self.protocol_encoder.forward(criteria_lst)
		return molecule_embed, icd_embed, protocol_embed	

	def forward_encoder_2_interaction(self, molecule_embed, icd_embed, protocol_embed):
		"""Concatenate the three embeddings and fuse via FC + Highway (ReLU after each)."""
		encoder_embedding = torch.cat([molecule_embed, icd_embed, protocol_embed], 1)
		# interaction_embedding = self.feed_lst_of_module(encoder_embedding, [self.encoder2interaction_fc, self.encoder2interaction_highway])
		h = self.encoder2interaction_fc(encoder_embedding)
		h = self.f(h)
		h = self.encoder2interaction_highway(h)
		interaction_embedding = self.f(h)
		return interaction_embedding 

	def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
		"""Return raw success logits of shape (batch, 1)."""
		molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst)
		interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed)
		output = self.pred_nn(interaction_embedding)
		return output ### 32, 1

	def evaluation(self, predict_all, label_all, threshold = 0.5):
		"""Compute binary-classification metrics from probabilities and labels.

		Side effect: writes (prediction, label) pairs to ``predict_label.txt``
		in the working directory.

		Returns:
			(auc, f1, pr-auc, precision, recall, accuracy,
			 predicted-positive ratio, true-positive ratio)
		"""
		import pickle, os
		from sklearn.metrics import roc_curve, precision_recall_curve
		with open("predict_label.txt", 'w') as fout:
			for i,j in zip(predict_all, label_all):
				fout.write(str(i)[:6] + '\t' + str(j)[:4]+'\n')
		auc_score = roc_auc_score(label_all, predict_all)
		figure_folder = "figure"
		#### ROC-curve 
		fpr, tpr, thresholds = roc_curve(label_all, predict_all, pos_label=1)
		# roc_curve =plt.figure()
		# plt.plot(fpr,tpr,'-',label=self.save_name + ' ROC Curve ')
		# plt.legend(fontsize = 15)
		# plt.savefig(os.path.join(figure_folder,self.save_name+"_roc_curve.png"))
		#### PR-curve
		precision, recall, thresholds = precision_recall_curve(label_all, predict_all)
		# plt.plot(recall,precision, label = self.save_name + ' PR Curve')
		# plt.legend(fontsize = 15)
		# plt.savefig(os.path.join(figure_folder,self.save_name + "_pr_curve.png"))
		label_all = [int(i) for i in label_all]
		float2binary = lambda x:0 if x < threshold else 1
		predict_all = list(map(float2binary, predict_all))
		f1score = f1_score(label_all, predict_all)
		# NOTE(review): PR-AUC is computed here on the *binarized* predictions;
		# average_precision_score normally expects the continuous scores —
		# confirm whether this degradation is intentional.
		prauc_score = average_precision_score(label_all, predict_all)
		# print(predict_all)
		precision = precision_score(label_all, predict_all)
		recall = recall_score(label_all, predict_all)
		accuracy = accuracy_score(label_all, predict_all)
		predict_1_ratio = sum(predict_all) / len(predict_all)
		label_1_ratio = sum(label_all) / len(label_all)
		return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio 

	def testloader_to_lst(self, dataloader):
		"""Flatten a dataloader into parallel python lists (plus their length)."""
		nctid_lst, label_lst, smiles_lst2, icdcode_lst3, criteria_lst = [], [], [], [], []
		for nctid, label, smiles, icdcode, criteria in dataloader:
			nctid_lst.extend(nctid)
			label_lst.extend([i.item() for i in label])
			smiles_lst2.extend(smiles)
			icdcode_lst3.extend(icdcode)
			criteria_lst.extend(criteria)
		length = len(nctid_lst)
		assert length == len(smiles_lst2) and length == len(icdcode_lst3)
		return nctid_lst, label_lst, smiles_lst2, icdcode_lst3, criteria_lst, length 

	def generate_predict(self, dataloader):
		"""Run inference over a dataloader.

		Returns:
			(summed batch loss, sigmoid probabilities, labels, trial ids)
		"""
		whole_loss = 0 
		label_all, predict_all, nctid_all = [], [], []
		for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader:
			nctid_all.extend(nctid_lst)
			label_vec = label_vec.to(self.device)
			output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst).view(-1)  
			loss = self.loss(output, label_vec.float())
			whole_loss += loss.item()
			# Convert logits to probabilities for metric computation.
			predict_all.extend([i.item() for i in torch.sigmoid(output)])
			label_all.extend([i.item() for i in label_vec])

		return whole_loss, predict_all, label_all, nctid_all

	def bootstrap_test(self, dataloader, valid_loader = None, sample_num = 20):
		"""Evaluate with bootstrap resampling; print mean/std of PR-AUC, F1, ROC-AUC.

		Side effects: saves a prediction histogram under ./figure/, prints
		misclassified trials, and pickles {nctid: prediction} to
		results/nctid2predict.pkl (both directories must already exist).
		"""
		best_threshold = 0.5
		# if validloader is not None:
		# 	best_threshold = self.select_threshold_for_binary(valid_loader)
		# 	print(f"best_threshold: {best_threshold}")
		self.eval()
		whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader)
		from HINT.utils import plot_hist
		plt.clf()
		prefix_name = "./figure/" + self.save_name 
		plot_hist(prefix_name, predict_all, label_all)		
		def bootstrap(length, sample_num):
			# sample_num resamples of indices, each drawn with replacement.
			idx = [i for i in range(length)]
			from random import choices 
			bootstrap_idx = [choices(idx, k = length) for i in range(sample_num)]
			return bootstrap_idx 
		results_lst = []
		bootstrap_idx_lst = bootstrap(len(predict_all), sample_num = sample_num)
		for bootstrap_idx in bootstrap_idx_lst: 
			bootstrap_label = [label_all[idx] for idx in bootstrap_idx]		
			bootstrap_predict = [predict_all[idx] for idx in bootstrap_idx]
			results = self.evaluation(bootstrap_predict, bootstrap_label, threshold = best_threshold)
			results_lst.append(results)
		self.train() 
		auc = [results[0] for results in results_lst]
		f1score = [results[1] for results in results_lst]
		prauc_score = [results[2] for results in results_lst]
		print("PR-AUC   mean: "+str(np.mean(prauc_score))[:6], "std: "+str(np.std(prauc_score))[:6])
		print("F1       mean: "+str(np.mean(f1score))[:6], "std: "+str(np.std(f1score))[:6])
		print("ROC-AUC  mean: "+str(np.mean(auc))[:6], "std: "+str(np.std(auc))[:6])

		# Print misclassified trials at the 0.5 threshold.
		for nctid, label, predict in zip(nctid_all, label_all, predict_all):
			if (predict > 0.5 and label == 0) or (predict < 0.5 and label == 1):
				print(nctid, label, str(predict)[:6])

		nctid2predict = {nctid:predict for nctid, predict in zip(nctid_all, predict_all)} 
		pickle.dump(nctid2predict, open('results/nctid2predict.pkl', 'wb'))
		return nctid_all, predict_all 

	def ongoing_test(self, dataloader, sample_num = 20):
		"""Inference-only pass for ongoing trials; returns (nctids, probabilities)."""
		self.eval()
		best_threshold = 0.5 
		whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader) 
		self.train() 
		return nctid_all, predict_all 
		
	def test(self, dataloader, return_loss = True, validloader=None):
		"""Evaluate on a dataloader.

		Returns (loss, predictions, labels) when ``return_loss`` is True;
		otherwise prints and returns the full metric tuple from ``evaluation``.
		"""
		# if validloader is not None:
		# 	best_threshold = self.select_threshold_for_binary(validloader)
		self.eval()
		best_threshold = 0.5 
		whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader)
		# from HINT.utils import plot_hist
		# plt.clf()
		# prefix_name = "./figure/" + self.save_name 
		# plot_hist(prefix_name, predict_all, label_all)
		self.train()
		if return_loss:
			return whole_loss, predict_all, label_all
		else:
			print_num = 6
			auc_score, f1score, prauc_score, precision, recall, accuracy, \
			predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all, threshold = best_threshold)
			print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num] \
				 + "\nPR-AUC: " + str(prauc_score)[:print_num] \
				 + "\nPrecision: " + str(precision)[:print_num] \
				 + "\nrecall: "+str(recall)[:print_num] + "\naccuracy: "+str(accuracy)[:print_num] \
				 + "\npredict 1 ratio: " + str(predict_1_ratio)[:print_num] \
				 + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num])
			return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio 

	def learn(self, train_loader, valid_loader, test_loader):
		"""Train with Adam, keep the validation-loss-best deepcopy, then report
		test metrics.

		Returns (train_output, valid_output) traces collected per step/epoch.

		NOTE(review): ``self = deepcopy(best_model)`` only rebinds the local
		name; the caller's model instance is NOT restored to the best
		checkpoint, and the final test below runs on that copy — confirm this
		is intended.
		"""
		opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
		train_loss_record = [] 
		valid_loss, valid_predict, valid_label = self.test(valid_loader, return_loss=True)
		valid_loss_record = [valid_loss]
		best_valid_loss = valid_loss
		best_model = deepcopy(self)
		train_output = []
		valid_output = []
		for ep in tqdm(range(self.epoch)):
			for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader:
				label_vec = label_vec.to(self.device)
				output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst).view(-1)  #### 32, 1 -> 32, ||  label_vec 32,
				loss = self.loss(output, label_vec.float())
				train_loss_record.append(loss.item())
				train_output.append((loss.item(), output, label_vec))
				opt.zero_grad()
				loss.backward()
				opt.step()
			valid_loss, valid_predict, valid_label = self.test(valid_loader, return_loss=True)
			valid_loss_record.append(valid_loss)
			valid_output.append((valid_loss, valid_predict, valid_label))

			print(f"valid_loss: {valid_loss}")
			print(best_valid_loss)
			if valid_loss < best_valid_loss:
				best_valid_loss = valid_loss 
				best_model = deepcopy(self)

		self.plot_learning_curve(train_loss_record, valid_loss_record)
		self = deepcopy(best_model)
		auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader)
		return train_output, valid_output

	def plot_learning_curve(self, train_loss_record, valid_loss_record):
		"""Save train/valid loss curves under ./figure/ (directory must exist)."""
		plt.plot(train_loss_record)
		plt.savefig("./figure/" + self.save_name + '_train_loss.jpg')
		plt.clf() 
		plt.plot(valid_loss_record)
		plt.savefig("./figure/" + self.save_name + '_valid_loss.jpg')
		plt.clf() 

	def select_threshold_for_binary(self, validloader):
		"""Grid-search a decision threshold over the validation predictions.

		NOTE(review): despite the ``f1score`` naming, this maximizes
		``precision_score``, and ``best_threshold`` is unbound if the
		prediction list is empty or no score beats 0 — confirm intent.
		"""
		_, prediction, label_all, nctid_all = self.generate_predict(validloader)
		best_f1 = 0
		for threshold in prediction:
			float2binary = lambda x:0 if x<threshold else 1
			predict_all = list(map(float2binary, prediction))
			f1score = precision_score(label_all, predict_all)        
			if f1score > best_f1:
				best_f1 = f1score 
				best_threshold = threshold
		return best_threshold 
+
+
class HINTModel_multi(Interaction):
	"""Four-class variant of Interaction (multi-class trial outcome).

	Replaces the binary head with a 4-way linear head and
	``nn.CrossEntropyLoss`` (expects raw logits and integer class labels).
	"""

	def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, 
					device, 
					global_embed_size, 
					highway_num_layer,
					prefix_name, 
					epoch = 20,
					lr = 3e-4, 
					weight_decay = 0, 
					):
		super(HINTModel_multi, self).__init__(molecule_encoder = molecule_encoder, 
								   disease_encoder = disease_encoder, 
								   protocol_encoder = protocol_encoder, 
								   device = device, 
								   prefix_name = prefix_name, 
								   global_embed_size = global_embed_size, 
								   highway_num_layer = highway_num_layer,
								   epoch = epoch,
								   lr = lr, 
								   weight_decay = weight_decay)
		# NOTE(review): this replacement head is created AFTER the base class
		# moved the module to `device` and is never .to(device)'d itself — on a
		# CUDA device this would leave pred_nn on CPU. Confirm against usage.
		self.pred_nn = nn.Linear(self.global_embed_size, 4)
		self.loss = nn.CrossEntropyLoss()

	def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
		"""Return raw 4-class logits of shape (batch, 4)."""
		molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst)
		interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed)
		output = self.pred_nn(interaction_embedding)
		return output ### 32, 4

	def generate_predict(self, dataloader):
		"""Run inference; returns (summed loss, argmax predictions, labels, accuracy).

		Unlike the base class, this drops the nctid list and returns accuracy
		as the fourth element.
		"""
		whole_loss = 0 
		label_all, predict_all = [], []
		for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader:
			label_vec = label_vec.to(self.device)
			output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst) 
			loss = self.loss(output, label_vec)
			whole_loss += loss.item()
			predict_all.extend(torch.argmax(output, 1).tolist())
			# predict_all.extend([i.item() for i in torch.sigmoid(output)])
			label_all.extend([i.item() for i in label_vec])

		accuracy = len(list(filter(lambda x:x[0]==x[1], zip(predict_all, label_all)))) / len(label_all)
		return whole_loss, predict_all, label_all, accuracy

	def test(self, dataloader, return_loss = True, validloader=None):
		"""Evaluate; always returns (loss, predictions, labels, accuracy).

		NOTE(review): ``return_loss`` and ``validloader`` are currently ignored
		(the metric branch below is commented out).
		"""
		# if validloader is not None:
		# 	best_threshold = self.select_threshold_for_binary(validloader)
		self.eval()
		whole_loss, predict_all, label_all, accuracy = self.generate_predict(dataloader)
		self.train()
		return whole_loss, predict_all, label_all, accuracy
		# # from HINT.utils import plot_hist
		# # plt.clf()
		# # prefix_name = "./figure/" + self.save_name 
		# # plot_hist(prefix_name, predict_all, label_all)
		# self.train()
		# if return_loss:
		# 	return whole_loss
		# else:
		# 	print_num = 5
		# 	auc_score, f1score, prauc_score, precision, recall, accuracy, \
		# 	predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all, threshold = best_threshold)
		# 	print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num] \
		# 		 + "\nPR-AUC: " + str(prauc_score)[:print_num] \
		# 		 + "\nPrecision: " + str(precision)[:print_num] \
		# 		 + "\nrecall: "+str(recall)[:print_num] + "\naccuracy: "+str(accuracy)[:print_num] \
		# 		 + "\npredict 1 ratio: " + str(predict_1_ratio)[:print_num] \
		# 		 + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num])
		# 	return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio 

	def learn(self, train_loader, valid_loader, test_loader):
		"""Train with Adam, printing validation accuracy each epoch.

		Returns the last epoch's (predictions, labels).
		NOTE(review): the statements after ``return`` below are dead code
		(early-stopping / best-model restore is effectively disabled).
		"""
		opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
		train_loss_record = []
		valid_loss, predict_all, label_all, accuracy = self.test(valid_loader, return_loss=True)
		print('accuracy', accuracy)
		# valid_loss_record = [valid_loss]
		# best_valid_loss = valid_loss
		best_model = deepcopy(self)
		for ep in tqdm(range(self.epoch)):
			self.train() 
			for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader:
				label_vec = label_vec.to(self.device)
				output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst)  #### 32, 1 -> 32, ||  label_vec 32,
				# print(label_vec.shape, output.shape, label_vec, output)
				loss = self.loss(output, label_vec)
				train_loss_record.append(loss.item())
				opt.zero_grad() 
				loss.backward() 
				opt.step()
			valid_loss, predict_all, label_all, accuracy = self.test(valid_loader, return_loss=True)
			print('accuracy', accuracy)
		return predict_all, label_all
		# 	valid_loss_record.append(valid_loss)
		# 	if valid_loss < best_valid_loss:
		# 		best_valid_loss = valid_loss 
		# 		best_model = deepcopy(self)

		# self.plot_learning_curve(train_loss_record, valid_loss_record)
		# self = deepcopy(best_model)
		# auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader)
+
+
class HINT_nograph(Interaction):
	"""HINT model without the GNN message-passing stage.

	Extends Interaction with additional intermediate nodes (disease risk,
	augmented interaction, five ADMET heads, PK, trial) computed by chained
	FC + Highway blocks; ``forward(..., if_gnn=True)`` exposes all node
	embeddings for the GNN subclass.
	"""
	def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, device, 
					global_embed_size, 
					highway_num_layer,
					prefix_name, 
					epoch = 20,
					lr = 3e-4, 
					weight_decay = 0, ):
		super(HINT_nograph, self).__init__(molecule_encoder = molecule_encoder, 
								   disease_encoder = disease_encoder, 
								   protocol_encoder = protocol_encoder,
								   device = device,  
								   global_embed_size = global_embed_size, 
								   prefix_name = prefix_name, 
								   highway_num_layer = highway_num_layer,
								   epoch = epoch,
								   lr = lr, 
								   weight_decay = weight_decay, 
								   ) 
		self.save_name = prefix_name + '_HINT_nograph'
		'''	### interaction model 
		self.molecule_encoder = molecule_encoder 
		self.disease_encoder = disease_encoder 
		self.protocol_encoder = protocol_encoder 
		self.global_embed_size = global_embed_size 
		self.highway_num_layer = highway_num_layer 
		self.feature_dim = self.molecule_encoder.embedding_size + self.disease_encoder.embedding_size + self.protocol_encoder.embedding_size
		self.epoch = epoch 
		self.lr = lr 
		self.weight_decay = weight_decay 
		self.save_name = save_name

		self.f = F.relu
		self.loss = nn.BCEWithLogitsLoss()

		##### NN 
		self.encoder2interaction_fc = nn.Linear(self.feature_dim, self.global_embed_size)
		self.encoder2interaction_highway = Highway(self.global_embed_size, self.highway_num_layer)
		self.pred_nn = nn.Linear(self.global_embed_size, 1)
		'''

		#### risk of disease 
		# NOTE: attribute name "higway" is a typo but is part of the public
		# surface (checkpoints / forward below) — do not rename casually.
		self.risk_disease_fc = nn.Linear(self.disease_encoder.embedding_size, self.global_embed_size)
		self.risk_disease_higway = Highway(self.global_embed_size, self.highway_num_layer)

		#### augment interaction 
		self.augment_interaction_fc = nn.Linear(self.global_embed_size*2, self.global_embed_size)
		self.augment_interaction_highway = Highway(self.global_embed_size, self.highway_num_layer)

		#### ADMET 
		# Five parallel FC+Highway heads (absorption, distribution, metabolism,
		# excretion, toxicity), each fed the molecule embedding.
		self.admet_model = []
		for i in range(5):
			admet_fc = nn.Linear(self.molecule_encoder.embedding_size, self.global_embed_size).to(device)
			admet_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
			self.admet_model.append(nn.ModuleList([admet_fc, admet_highway])) 
		self.admet_model = nn.ModuleList(self.admet_model)

		#### PK 
		self.pk_fc = nn.Linear(self.global_embed_size*5, self.global_embed_size)
		self.pk_highway = Highway(self.global_embed_size, self.highway_num_layer)

		#### trial node 
		self.trial_fc = nn.Linear(self.global_embed_size*2, self.global_embed_size)
		self.trial_highway = Highway(self.global_embed_size, self.highway_num_layer)

		## self.pred_nn = nn.Linear(self.global_embed_size, 1)

		self.device = device 
		self = self.to(device)

	def forward(self, smiles_lst2, icdcode_lst3, criteria_lst, if_gnn = False):
		"""Compute the trial logit (default) or, with ``if_gnn=True``, the list
		of all 13 node embeddings used by the GNN subclass."""
		### encoder for molecule, disease and protocol
		molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst)
		### interaction 
		interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed)
		### risk of disease 
		risk_of_disease_embedding = self.feed_lst_of_module(input_feature = icd_embed, 
															lst_of_module = [self.risk_disease_fc, self.risk_disease_higway])
		### augment interaction   
		augment_interaction_input = torch.cat([interaction_embedding, risk_of_disease_embedding], 1)
		augment_interaction_embedding = self.feed_lst_of_module(input_feature = augment_interaction_input, 
																lst_of_module = [self.augment_interaction_fc, self.augment_interaction_highway])
		### admet 
		admet_embedding_lst = []
		for idx in range(5):
			admet_embedding = self.feed_lst_of_module(input_feature = molecule_embed, 
													  lst_of_module = self.admet_model[idx])
			admet_embedding_lst.append(admet_embedding)
		### pk 
		pk_input = torch.cat(admet_embedding_lst, 1)
		pk_embedding = self.feed_lst_of_module(input_feature = pk_input, 
											   lst_of_module = [self.pk_fc, self.pk_highway])
		### trial 
		trial_input = torch.cat([pk_embedding, augment_interaction_embedding], 1)
		trial_embedding = self.feed_lst_of_module(input_feature = trial_input, 
												  lst_of_module = [self.trial_fc, self.trial_highway])
		output = self.pred_nn(trial_embedding)
		if if_gnn == False:
			return output 
		else:
			# Order must stay consistent with HINTModel.generate_adj's node list.
			embedding_lst = [molecule_embed, icd_embed, protocol_embed, interaction_embedding, risk_of_disease_embedding, \
							 augment_interaction_embedding] + admet_embedding_lst + [pk_embedding, trial_embedding]
			return embedding_lst
+
+
class HINTModel(HINT_nograph):
	"""Full HINT model: HINT_nograph's node embeddings refined by a GCN over a
	fixed 13-node trial graph, with learned per-edge attention weights."""

	def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, 
					device, 
					global_embed_size, 
					highway_num_layer,
					prefix_name, 
					gnn_hidden_size, 
					epoch = 20,
					lr = 3e-4, 
					weight_decay = 0,):
		super(HINTModel, self).__init__(molecule_encoder = molecule_encoder, 
								   disease_encoder = disease_encoder, 
								   protocol_encoder = protocol_encoder, 
								   device = device, 
								   prefix_name = prefix_name, 
								   global_embed_size = global_embed_size, 
								   highway_num_layer = highway_num_layer,
								   epoch = epoch,
								   lr = lr, 
								   weight_decay = weight_decay)
		self.save_name = prefix_name 
		self.gnn_hidden_size = gnn_hidden_size 
		#### GNN 
		self.adj = self.generate_adj()          
		self.gnn = GCN(
            nfeat = self.global_embed_size,
            nhid = self.gnn_hidden_size,
            nclass = 1,
            dropout = 0.6,
		    init = 'uniform') 
		### gnn's attention 		
		self.node_size = self.adj.shape[0]
		'''
		self.graph_attention_model_mat = nn.ModuleList([nn.ModuleList([self.gnn_attention() \
        								if self.adj[i,j]==1 else None  \
        								for j in range(self.node_size)]) \
        								for i in range(self.node_size)])
        '''
		# One attention MLP per adjacency edge; None where no edge exists.
		self.graph_attention_model_mat = nn.ModuleList([nn.ModuleList([self.gnn_attention() if self.adj[i,j]==1 else None for j in range(self.node_size)]) for i in range(self.node_size)])
		# self.graph_attention_model_mat = nn.ModuleList([nn.ModuleList([self.gnn_attention() if self.adj[i,j]==1 else None for j in range(self.node_size)]) for i in range(self.node_size)])

		'''
nn.ModuleList([ nn.ModuleList([nn.Linear(3,2) for j in range(5)] + [None]) for i in range(3)])
		'''

		self.device = device 
		self = self.to(device)

	def generate_adj(self):        								
		##### consistent with HINT_nograph.forward
		"""Build the symmetric 13x13 adjacency of the trial graph on self.device.

		NOTE(review): the initial ``torch.zeros`` is immediately overwritten,
		and the diagonal is set to ``len(lst)`` (13) rather than 1 — presumably
		a deliberate self-loop weighting for the GCN, but confirm.
		"""
		lst = ["molecule", "disease", "criteria", 'INTERACTION', 'risk_disease', 'augment_interaction', 'A', 'D', 'M', 'E', 'T', 'PK', "final"]
		edge_lst = [("disease", "molecule"), ("disease", "criteria"), ("molecule", "criteria"), 
					("disease", "INTERACTION"), ("molecule", "INTERACTION"),  ("criteria", "INTERACTION"), 
					("disease", "risk_disease"), ('risk_disease', 'augment_interaction'), ('INTERACTION', 'augment_interaction'),
					("molecule", "A"), ("molecule", "D"), ("molecule", "M"), ("molecule", "E"), ("molecule", "T"),
					('A', 'PK'), ('D', 'PK'), ('M', 'PK'), ('E', 'PK'), ('T', 'PK'), 
					('augment_interaction', 'final'), ('PK', 'final')]
		adj = torch.zeros(len(lst), len(lst))
		adj = torch.eye(len(lst)) * len(lst)
		num2str = {k:v for k,v in enumerate(lst)}
		str2num = {v:k for k,v in enumerate(lst)}
		for i,j in edge_lst:
			n1,n2 = str2num[i], str2num[j]
			adj[n1,n2] = 1
			adj[n2,n1] = 1
		return adj.to(self.device) 

	def generate_attention_matrx(self, node_feature_mat):
		"""Sigmoid attention weight for every adjacency edge of one sample's
		(node_size, embed) feature matrix; zeros elsewhere."""
		attention_mat = torch.zeros(self.node_size, self.node_size).to(self.device)
		for i in range(self.node_size):
			for j in range(self.node_size):
				if self.adj[i,j]!=1:
					continue 
				feature = torch.cat([node_feature_mat[i].view(1,-1), node_feature_mat[j].view(1,-1)], 1)
				attention_model = self.graph_attention_model_mat[i][j]
				attention_mat[i,j] = torch.sigmoid(self.feed_lst_of_module(input_feature=feature, lst_of_module=attention_model))
		return attention_mat 

	##### self.global_embed_size*2 -> 1 
	def gnn_attention(self):
		"""Factory for one edge-attention block: Highway + Linear(2*embed -> 1)."""
		highway_nn = Highway(size = self.global_embed_size*2, num_layers = self.highway_num_layer).to(self.device)
		highway_fc = nn.Linear(self.global_embed_size*2, 1).to(self.device)
		return nn.ModuleList([highway_nn, highway_fc])	

	def forward(self, smiles_lst2, icdcode_lst3, criteria_lst, return_attention_matrix = False):
		"""GNN forward: per-sample attention + GCN over the trial graph.

		Returns the final-node logits (batch, 1); with
		``return_attention_matrix=True`` also returns per-sample attention
		matrices. Runs one GCN pass per sample (Python loop over the batch).
		"""
		embedding_lst = HINT_nograph.forward(self, smiles_lst2, icdcode_lst3, criteria_lst, if_gnn = True)
		### length is 13, each is 32,50 
		batch_size = embedding_lst[0].shape[0]
		output_lst = []
		if return_attention_matrix:
			attention_mat_lst = []
		for i in range(batch_size):
			node_feature_lst = [embedding[i].view(1,-1) for embedding in embedding_lst]
			node_feature_mat = torch.cat(node_feature_lst, 0) ### 13, 50 
			attention_mat = self.generate_attention_matrx(node_feature_mat)
			output = self.gnn(node_feature_mat, self.adj * attention_mat)
			# Keep only the "final" (last) node's value as the trial logit.
			output = output[-1].view(1,-1)
			output_lst.append(output)
			if return_attention_matrix:
				attention_mat_lst.append(attention_mat)
		output_mat = torch.cat(output_lst, 0)
		if not return_attention_matrix:
			return output_mat 
		else:
			return output_mat, attention_mat_lst

	def interpret(self, complete_dataloader):
		"""Render one attention-graph PNG per trial under interpret_result/.

		NOTE(review): the truncation guard tests ``len(name) > 150`` but slices
		``name[:250]`` — looks like an off-by-100 (intended ``[:150]``?).
		"""
		from graph_visualize_interpret import data2graph 
		from HINT.utils import replace_strange_symbol
		for nctid_lst, status_lst, why_stop_lst, label_vec, phase_lst, \
			diseases_lst, icdcode_lst3, drugs_lst, smiles_lst2, criteria_lst in complete_dataloader: 
			output, attention_mat_lst = self.forward(smiles_lst2, icdcode_lst3, criteria_lst, return_attention_matrix=True)
			output = output.view(-1)
			batch_size = len(nctid_lst)
			for i in range(batch_size):
				name = '__'.join([nctid_lst[i], status_lst[i], why_stop_lst[i], \
														str(label_vec[i].item()), str(torch.sigmoid(output[i]).item())[:5], \
														phase_lst[i], diseases_lst[i], drugs_lst[i]])
				if len(name) > 150:
					name = name[:250]
				name = replace_strange_symbol(name)
				name = name.replace('__', '_')
				name = name.replace('  ', ' ')
				name = 'interpret_result/' + name + '.png'
				print(name)
				data2graph(attention_matrix = attention_mat_lst[i], adj = self.adj, save_name = name)

	def init_pretrain(self, admet_model):
		"""Warm-start the molecule encoder from a pretrained ADMET model."""
		self.molecule_encoder = admet_model.molecule_encoder

	### generate attention matrix 
+
+
class Only_Molecule(Interaction):
	"""Ablation model: predicts the trial outcome from the molecule (SMILES)
	embedding alone; disease and criteria inputs are ignored in forward()."""

	def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, 
					global_embed_size, 
					highway_num_layer,
					prefix_name, 
					epoch = 20,
					lr = 3e-4, 
					weight_decay = 0,
					device = None):
		# Bug fix: Interaction.__init__ requires a `device` argument, but the
		# original super() call omitted it, so constructing this class raised a
		# TypeError. `device` is accepted as a trailing keyword (backward
		# compatible with existing positional callers) and defaults to CPU.
		if device is None:
			device = torch.device('cpu')
		super(Only_Molecule, self).__init__(molecule_encoder=molecule_encoder, 
											disease_encoder=disease_encoder, 
											protocol_encoder=protocol_encoder, 
											device = device,
											global_embed_size = global_embed_size, 
											highway_num_layer = highway_num_layer,
											prefix_name = prefix_name, 
											epoch = epoch,
											lr = lr, 
											weight_decay = weight_decay,)
		# Dedicated output head mapping the molecule embedding to one logit.
		self.molecule2out = nn.Linear(self.global_embed_size,1)


	def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
		"""Return raw logits (batch, 1) from the molecule embedding only."""
		molecule_embed = self.molecule_encoder.forward_smiles_lst_lst(smiles_lst2)
		return self.molecule2out(molecule_embed)
+
class Only_Disease(Only_Molecule):
	"""Ablation model: predicts the trial outcome from the disease (ICD code)
	embedding alone; molecule and criteria inputs are ignored in forward().

	NOTE(review): the super().__init__ chain does not forward a `device`,
	while Interaction.__init__ declares one as required — verify how this
	class is constructed upstream.
	"""

	def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, 
					global_embed_size, 
					highway_num_layer,
					prefix_name, 
					epoch = 20,
					lr = 3e-4, 
					weight_decay = 0):
		super(Only_Disease, self).__init__(
			molecule_encoder = molecule_encoder,
			disease_encoder = disease_encoder,
			protocol_encoder = protocol_encoder,
			global_embed_size = global_embed_size,
			highway_num_layer = highway_num_layer,
			prefix_name = prefix_name,
			epoch = epoch,
			lr = lr,
			weight_decay = weight_decay,
		)
		# Alias the shared output head under a disease-specific name.
		self.disease2out = self.molecule2out

	def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
		"""Return raw logits (batch, 1) from the disease embedding only."""
		return self.disease2out(self.disease_encoder.forward_code_lst3(icdcode_lst3))
+
def dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, global_icd):
	"""Convert one minibatch into a flat feature matrix for the FFNN baseline.

	Features are [molecule fingerprint (2048) | multi-hot ICD-prefix vector].

	Args:
		nctid_lst, criteria_lst: unused here; kept for a uniform batch signature.
		label_vec: label tensor of shape (n,), returned unchanged as y.
		smiles_lst2: list (per trial) of SMILES lists.
		icdcode_lst3: list (per trial) of lists of ICD-code lists.
		global_icd: ordered vocabulary of ICD-code prefixes.

	Returns:
		(X, y) with X a float tensor of shape (n, 2048 + len(global_icd)).
	"""
	## label_vec: (n,)
	y = label_vec 

	num_icd = len(global_icd)
	from HINT.utils import smiles_lst2fp 
	# Fingerprint each trial's drug set -> (n, 2048).
	fp_lst = [smiles_lst2fp(smiles_lst).reshape(1,-1) for smiles_lst in smiles_lst2]
	fp_mat = np.concatenate(fp_lst, 0)

	# Precompute prefix -> first index once: the original used
	# `global_icd.index(ele)` and `ele in global_icd` inside the loop,
	# an O(len(global_icd)) scan per code. setdefault keeps the FIRST
	# occurrence, matching list.index semantics on duplicates.
	icd2idx = {}
	for idx, code in enumerate(global_icd):
		icd2idx.setdefault(code, idx)

	icdcode_lst = []
	for lst2 in icdcode_lst3:
		# Flatten the per-trial code lists and keep the ICD "chapter" prefix
		# (text before the dot). The nested comprehension also handles an
		# empty lst2, where the original reduce() raised TypeError.
		prefixes = {code.split('.')[0] for lst in lst2 for code in lst}
		icd_feature = np.zeros((1,num_icd), np.int32)
		for ele in prefixes:
			idx = icd2idx.get(ele)
			if idx is not None:
				icd_feature[0, idx] = 1
		icdcode_lst.append(icd_feature)
	icdcode_mat = np.concatenate(icdcode_lst, 0)

	X = np.concatenate([fp_mat, icdcode_mat], 1)
	X = torch.from_numpy(X)
	X = X.float()
	return X, y 
+
+
+class FFNN(nn.Sequential):
+	def __init__(self, molecule_dim, diseasecode_dim, 
+					global_icd, 
+					protocol_dim = 0,
+					prefix_name = 'FFNN', 
+					epoch = 10,
+					lr = 3e-4, 
+					weight_decay = 0, 
+					):
+		super(FFNN, self).__init__()
+		self.molecule_dim = molecule_dim 
+		self.diseasecode_dim = diseasecode_dim 
+		self.protocol_dim = protocol_dim 
+		self.prefix_name = prefix_name 
+		self.epoch = epoch 
+		self.lr = lr 
+		self.weight_decay = weight_decay 
+		self.global_icd = global_icd 
+		self.num_icd = len(global_icd)
+
+		self.fc_dims = [self.molecule_dim + self.diseasecode_dim + self.protocol_dim, 2000, 1000, 200, 50, 1]
+		self.fc_layers = nn.ModuleList([nn.Linear(v,self.fc_dims[i+1]) for i,v in enumerate(self.fc_dims[:-1])])
+		self.loss = nn.BCEWithLogitsLoss()
+		self.save_name = prefix_name 
+
+	def forward(self, X):
+		for i in range(len(self.fc_layers) - 1):
+			fc_layer = self.fc_layers[i]
+			X = fc_layer(X)
+		last_layer = self.fc_layers[-1]
+		pred = F.sigmoid(last_layer(X))
+		return pred 
+
	def learn(self, train_loader, valid_loader, test_loader):
		"""Train the FFNN with Adam, tracking validation loss per epoch.

		Features are built per batch via dataloader2Xy. Assumes ``self.test``
		(defined elsewhere in this file) returns the summed loss when
		``return_loss=True``.

		NOTE(review): ``self = deepcopy(best_model)`` only rebinds the local
		name — the caller's instance is NOT restored to the best checkpoint,
		and the final test runs on the copy. Confirm intent.
		"""
		opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
		train_loss_record = [] 
		valid_loss = self.test(valid_loader, return_loss=True)
		valid_loss_record = [valid_loss]
		best_valid_loss = valid_loss
		best_model = deepcopy(self)

		for ep in tqdm(range(self.epoch)):
			for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader:
				X, _ = dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, self.global_icd)
				output = self.forward(X).view(-1)  #### 32, 1 -> 32, ||  label_vec 32,
				loss = self.loss(output, label_vec.float())
				train_loss_record.append(loss.item())
				opt.zero_grad() 
				loss.backward() 
				opt.step()
			valid_loss = self.test(valid_loader, return_loss=True)
			valid_loss_record.append(valid_loss)
			if valid_loss < best_valid_loss:
				best_valid_loss = valid_loss 
				best_model = deepcopy(self)

		self.plot_learning_curve(train_loss_record, valid_loss_record)
		self = deepcopy(best_model)
		auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader)
+
+	def evaluation(self, predict_all, label_all, threshold = 0.5):
+		import pickle, os
+		from sklearn.metrics import roc_curve, precision_recall_curve
+		with open("predict_label.txt", 'w') as fout:
+			for i,j in zip(predict_all, label_all):
+				fout.write(str(i)[:4] + '\t' + str(j)[:4]+'\n')
+		auc_score = roc_auc_score(label_all, predict_all)
+		figure_folder = "figure"
+		#### ROC-curve 
+		fpr, tpr, thresholds = roc_curve(label_all, predict_all, pos_label=1)
+		# roc_curve =plt.figure()
+		# plt.plot(fpr,tpr,'-',label=self.save_name + ' ROC Curve ')
+		# plt.legend(fontsize = 15)
+		#plt.savefig(os.path.join(figure_folder,name+"_roc_curve.png"))
+		#### PR-curve
+		precision, recall, thresholds = precision_recall_curve(label_all, predict_all)
+		# plt.plot(recall,precision, label = self.save_name + ' PR Curve')
+		# plt.legend(fontsize = 15)
+		# plt.savefig(os.path.join(figure_folder,self.save_name + "_pr_curve.png"))
+		label_all = [int(i) for i in label_all]
+		float2binary = lambda x:0 if x<threshold else 1
+		predict_all = list(map(float2binary, predict_all))
+		f1score = f1_score(label_all, predict_all)
+		prauc_score = average_precision_score(label_all, predict_all)
+		# print(predict_all)
+		precision = precision_score(label_all, predict_all)
+		recall = recall_score(label_all, predict_all)
+		accuracy = accuracy_score(label_all, predict_all)
+		predict_1_ratio = sum(predict_all) / len(predict_all)
+		label_1_ratio = sum(label_all) / len(label_all)
+		return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio 
+
+	def generate_predict(self, dataloader):
+		whole_loss = 0 
+		label_all, predict_all = [], []
+		for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader:
+			X, _ = dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, self.global_icd) 
+			output = self.forward(X).view(-1)  
+			loss = self.loss(output, label_vec.float())
+			whole_loss += loss.item()
+			predict_all.extend([i.item() for i in torch.sigmoid(output)])
+			label_all.extend([i.item() for i in label_vec])
+
+		return whole_loss, predict_all, label_all
+
+	def bootstrap_test(self, dataloader, validloader = None, sample_num = 20):
+		best_threshold = 0.5
+		# if validloader is not None:
+		# 	best_threshold = self.select_threshold_for_binary(validloader)
+		self.eval()
+		whole_loss, predict_all, label_all = self.generate_predict(dataloader)
+		from HINT.utils import plot_hist
+		plt.clf()
+		prefix_name = "./figure/" + self.save_name 
+		plot_hist(prefix_name, predict_all, label_all)		
+		def bootstrap(length, sample_num):
+			idx = [i for i in range(length)]
+			from random import choices 
+			bootstrap_idx = [choices(idx, k = length) for i in range(sample_num)]
+			return bootstrap_idx 
+		results_lst = []
+		bootstrap_idx_lst = bootstrap(len(predict_all), sample_num = sample_num)
+		for bootstrap_idx in bootstrap_idx_lst: 
+			bootstrap_label = [label_all[idx] for idx in bootstrap_idx]		
+			bootstrap_predict = [predict_all[idx] for idx in bootstrap_idx]
+			results = self.evaluation(bootstrap_predict, bootstrap_label, threshold = best_threshold)
+			results_lst.append(results)
+		self.train() 
+		auc = [results[0] for results in results_lst]
+		f1score = [results[1] for results in results_lst]
+		prauc_score = [results[2] for results in results_lst]
+		print("PR-AUC   mean: "+str(np.mean(prauc_score))[:6], "std: "+str(np.std(prauc_score))[:6])
+		print("F1       mean: "+str(np.mean(f1score))[:6], "std: "+str(np.std(f1score))[:6])
+		print("ROC-AUC  mean: "+ str(np.mean(auc))[:6], "std: " + str(np.std(auc))[:6])
+
+	def test(self, dataloader, return_loss = True, validloader=None):
+		# if validloader is not None:
+		# 	best_threshold = self.select_threshold_for_binary(validloader)
+		self.eval()
+		best_threshold = 0.5 
+		whole_loss, predict_all, label_all = self.generate_predict(dataloader)
+		# from HINT.utils import plot_hist
+		# plt.clf()
+		# prefix_name = "./figure/" + self.save_name 
+		# plot_hist(prefix_name, predict_all, label_all)
+		self.train()
+		if return_loss:
+			return whole_loss
+		else:
+			print_num = 5
+			auc_score, f1score, prauc_score, precision, recall, accuracy, \
+			predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all, threshold = best_threshold)
+			print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num] \
+				 + "\nPR-AUC: " + str(prauc_score)[:print_num] \
+				 + "\nPrecision: " + str(precision)[:print_num] \
+				 + "\nrecall: "+str(recall)[:print_num] + "\naccuracy: "+str(accuracy)[:print_num] \
+				 + "\npredict 1 ratio: " + str(predict_1_ratio)[:print_num] \
+				 + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num])
+			return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio 
+
+	def plot_learning_curve(self, train_loss_record, valid_loss_record):
+		plt.plot(train_loss_record)
+		plt.savefig("./figure/" + self.save_name + '_train_loss.jpg')
+		plt.clf() 
+		plt.plot(valid_loss_record)
+		plt.savefig("./figure/" + self.save_name + '_valid_loss.jpg')
+		plt.clf() 
+
+
class ADMET(nn.Sequential):
	"""Multi-task ADMET predictor sharing one molecule MPNN backbone.

	Each of the self.num (=5) ADMET properties gets its own
	(Linear -> Highway) tower plus a scalar prediction head; forward(...,
	idx) selects the task. Heads output raw logits (train with
	BCEWithLogitsLoss).
	"""
	def __init__(self, mpnn_model, device, epoch = 20, lr = 3e-4, weight_decay = 0):
		super(ADMET, self).__init__()
		self.num = 5 	# number of ADMET tasks
		self.mpnn_model = mpnn_model 
		self.device = device 
		self.mpnn_dim = mpnn_model.mpnn_hidden_size 
		self.global_embed_size = self.mpnn_dim 
		self.highway_num_layer = 2 
		# BUG FIX: learn() reads these hyper-parameters and self.loss, but the
		# original __init__ never defined them (AttributeError at train time).
		# Defaults mirror the other model classes in this file.
		self.epoch = epoch 
		self.lr = lr 
		self.weight_decay = weight_decay 
		self.loss = nn.BCEWithLogitsLoss()

		# one (fc, highway) tower per task; use self.num instead of literal 5
		admet_model = []
		for i in range(self.num):
			admet_fc = nn.Linear(self.mpnn_dim, self.global_embed_size).to(device)
			admet_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
			admet_model.append(nn.ModuleList([admet_fc, admet_highway]))
		self.admet_model = nn.ModuleList(admet_model)

		self.admet_pred = nn.ModuleList([nn.Linear(self.global_embed_size,1).to(device) for i in range(self.num)])
		self.f = F.relu 
		self = self.to(device)

	def feed_lst_of_module(self, input_feature, lst_of_module):
		"""Apply each module in order, with self.f (ReLU) after every one."""
		x = input_feature
		for single_module in lst_of_module:
			x = self.f(single_module(x))
		return x 

	def forward(self, smiles_lst, idx):
		"""Predict ADMET property `idx` (logits) for a batch of SMILES lists.

		Args:
			smiles_lst: batch of SMILES string lists fed to the MPNN encoder.
			idx: task index in [0, self.num).
		"""
		assert idx in range(self.num)
		embeds = self.mpnn_model.forward_smiles_lst_lst(smiles_lst)
		embeds = self.feed_lst_of_module(embeds, self.admet_model[idx]) 
		output = self.admet_pred[idx](embeds)
		return output 

	def test(self, valid_loader, return_loss = True):
		# Placeholder: learn() calls self.test(valid_loader, return_loss=True),
		# so the signature must accept return_loss (the original raised a
		# TypeError). Returning +inf keeps learn()'s best-loss bookkeeping
		# well-defined until a real validation pass is implemented.
		# TODO(review): implement validation.
		return float("inf")

	def learn(self, train_loader, valid_loader, idx):
		"""Train the task-`idx` head (plus shared backbone); keep best checkpoint."""
		opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
		train_loss_record = [] 
		valid_loss = self.test(valid_loader, return_loss=True)
		valid_loss_record = [valid_loss]
		best_valid_loss = valid_loss
		best_model = deepcopy(self)

		for ep in tqdm(range(self.epoch)):
			# NOTE(review): assumes the loader yields (smiles_lst, label_vec)
			# batches — the original unpacked only smiles_lst yet used an
			# undefined label_vec. TODO confirm the loader's batch structure.
			for smiles_lst, label_vec in train_loader:
				# BUG FIX: original called self.forward(smiles_lst) without the
				# required task index.
				output = self.forward(smiles_lst, idx).view(-1)  #### (B, 1) -> (B,)
				loss = self.loss(output, label_vec.float())
				train_loss_record.append(loss.item())
				opt.zero_grad() 
				loss.backward() 
				opt.step()
			valid_loss = self.test(valid_loader, return_loss=True)
			valid_loss_record.append(valid_loss)
			if valid_loss < best_valid_loss:
				best_valid_loss = valid_loss 
				best_model = deepcopy(self)

		# BUG FIX: "self = deepcopy(best_model)" only rebinds the local name;
		# restore the best weights in place instead.
		self.load_state_dict(best_model.state_dict())