|
# HINT/model.py
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, precision_score, recall_score, accuracy_score
import matplotlib.pyplot as plt
from copy import deepcopy
import numpy as np
from tqdm import tqdm
import torch
torch.manual_seed(0)
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from HINT.module import Highway, GCN
from functools import reduce
import pickle
|
|
class Interaction(nn.Sequential):
    def __init__(self, molecule_encoder, disease_encoder, protocol_encoder,
                 device,
                 global_embed_size,
                 highway_num_layer,
                 prefix_name,
                 epoch = 20,
                 lr = 3e-4,
                 weight_decay = 0,
                 ):
        super(Interaction, self).__init__()
        self.molecule_encoder = molecule_encoder
        self.disease_encoder = disease_encoder
        self.protocol_encoder = protocol_encoder
        self.global_embed_size = global_embed_size
        self.highway_num_layer = highway_num_layer
        self.feature_dim = self.molecule_encoder.embedding_size + self.disease_encoder.embedding_size + self.protocol_encoder.embedding_size
        self.epoch = epoch
        self.lr = lr
        self.weight_decay = weight_decay
        self.save_name = prefix_name + '_interaction'

        self.f = F.relu
        self.loss = nn.BCEWithLogitsLoss()

        ##### NN
        self.encoder2interaction_fc = nn.Linear(self.feature_dim, self.global_embed_size).to(device)
        self.encoder2interaction_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
        self.pred_nn = nn.Linear(self.global_embed_size, 1)

        self.device = device
        self = self.to(device)

    def feed_lst_of_module(self, input_feature, lst_of_module):
        x = input_feature
        for single_module in lst_of_module:
            x = self.f(single_module(x))
        return x

    def forward_get_three_encoders(self, smiles_lst2, icdcode_lst3, criteria_lst):
        molecule_embed = self.molecule_encoder.forward_smiles_lst_lst(smiles_lst2)
        icd_embed = self.disease_encoder.forward_code_lst3(icdcode_lst3)
        protocol_embed = self.protocol_encoder.forward(criteria_lst)
        return molecule_embed, icd_embed, protocol_embed

    def forward_encoder_2_interaction(self, molecule_embed, icd_embed, protocol_embed):
        encoder_embedding = torch.cat([molecule_embed, icd_embed, protocol_embed], 1)
        # interaction_embedding = self.feed_lst_of_module(encoder_embedding, [self.encoder2interaction_fc, self.encoder2interaction_highway])
        h = self.encoder2interaction_fc(encoder_embedding)
        h = self.f(h)
        h = self.encoder2interaction_highway(h)
        interaction_embedding = self.f(h)
        return interaction_embedding

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
        molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst)
        interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed)
        output = self.pred_nn(interaction_embedding)
        return output  ### 32, 1
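
    # A minimal, self-contained sketch of the fusion that forward() performs,
    # with toy tensors standing in for the three encoder outputs (all sizes
    # below are made up, and the Highway block from HINT.module is omitted):
    #
    #     mol, icd, proto = torch.randn(32, 50), torch.randn(32, 50), torch.randn(32, 50)
    #     fc   = nn.Linear(150, 50)   # feature_dim -> global_embed_size
    #     head = nn.Linear(50, 1)     # global_embed_size -> 1
    #     h = torch.relu(fc(torch.cat([mol, icd, proto], 1)))
    #     logits = head(h)            # (32, 1) raw logits, fed to BCEWithLogitsLoss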
|
|
    def evaluation(self, predict_all, label_all, threshold = 0.5):
        import pickle, os
        from sklearn.metrics import roc_curve, precision_recall_curve
        with open("predict_label.txt", 'w') as fout:
            for i, j in zip(predict_all, label_all):
                fout.write(str(i)[:6] + '\t' + str(j)[:4] + '\n')
        auc_score = roc_auc_score(label_all, predict_all)
        figure_folder = "figure"
        #### ROC-curve
        fpr, tpr, thresholds = roc_curve(label_all, predict_all, pos_label=1)
        # roc_curve = plt.figure()
        # plt.plot(fpr, tpr, '-', label = self.save_name + ' ROC Curve ')
        # plt.legend(fontsize = 15)
        # plt.savefig(os.path.join(figure_folder, self.save_name + "_roc_curve.png"))
        #### PR-curve
        precision, recall, thresholds = precision_recall_curve(label_all, predict_all)
        # plt.plot(recall, precision, label = self.save_name + ' PR Curve')
        # plt.legend(fontsize = 15)
        # plt.savefig(os.path.join(figure_folder, self.save_name + "_pr_curve.png"))
        label_all = [int(i) for i in label_all]
        float2binary = lambda x: 0 if x < threshold else 1
        predict_all = list(map(float2binary, predict_all))
        f1score = f1_score(label_all, predict_all)
        prauc_score = average_precision_score(label_all, predict_all)
        # print(predict_all)
        precision = precision_score(label_all, predict_all)
        recall = recall_score(label_all, predict_all)
        accuracy = accuracy_score(label_all, predict_all)
        predict_1_ratio = sum(predict_all) / len(predict_all)
        label_1_ratio = sum(label_all) / len(label_all)
        return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio
|
|
    def testloader_to_lst(self, dataloader):
        nctid_lst, label_lst, smiles_lst2, icdcode_lst3, criteria_lst = [], [], [], [], []
        for nctid, label, smiles, icdcode, criteria in dataloader:
            nctid_lst.extend(nctid)
            label_lst.extend([i.item() for i in label])
            smiles_lst2.extend(smiles)
            icdcode_lst3.extend(icdcode)
            criteria_lst.extend(criteria)
        length = len(nctid_lst)
        assert length == len(smiles_lst2) and length == len(icdcode_lst3)
        return nctid_lst, label_lst, smiles_lst2, icdcode_lst3, criteria_lst, length

    def generate_predict(self, dataloader):
        whole_loss = 0
        label_all, predict_all, nctid_all = [], [], []
        for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader:
            nctid_all.extend(nctid_lst)
            label_vec = label_vec.to(self.device)
            output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst).view(-1)
            loss = self.loss(output, label_vec.float())
            whole_loss += loss.item()
            predict_all.extend([i.item() for i in torch.sigmoid(output)])
            label_all.extend([i.item() for i in label_vec])

        return whole_loss, predict_all, label_all, nctid_all
|
|
    def bootstrap_test(self, dataloader, valid_loader = None, sample_num = 20):
        best_threshold = 0.5
        # if validloader is not None:
        #     best_threshold = self.select_threshold_for_binary(valid_loader)
        #     print(f"best_threshold: {best_threshold}")
        self.eval()
        whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader)
        from HINT.utils import plot_hist
        plt.clf()
        prefix_name = "./figure/" + self.save_name
        plot_hist(prefix_name, predict_all, label_all)

        def bootstrap(length, sample_num):
            idx = [i for i in range(length)]
            from random import choices
            bootstrap_idx = [choices(idx, k = length) for i in range(sample_num)]
            return bootstrap_idx

        results_lst = []
        bootstrap_idx_lst = bootstrap(len(predict_all), sample_num = sample_num)
        for bootstrap_idx in bootstrap_idx_lst:
            bootstrap_label = [label_all[idx] for idx in bootstrap_idx]
            bootstrap_predict = [predict_all[idx] for idx in bootstrap_idx]
            results = self.evaluation(bootstrap_predict, bootstrap_label, threshold = best_threshold)
            results_lst.append(results)
        self.train()
        auc = [results[0] for results in results_lst]
        f1score = [results[1] for results in results_lst]
        prauc_score = [results[2] for results in results_lst]
        print("PR-AUC mean: " + str(np.mean(prauc_score))[:6], "std: " + str(np.std(prauc_score))[:6])
        print("F1 mean: " + str(np.mean(f1score))[:6], "std: " + str(np.std(f1score))[:6])
        print("ROC-AUC mean: " + str(np.mean(auc))[:6], "std: " + str(np.std(auc))[:6])

        for nctid, label, predict in zip(nctid_all, label_all, predict_all):
            if (predict > 0.5 and label == 0) or (predict < 0.5 and label == 1):
                print(nctid, label, str(predict)[:6])

        nctid2predict = {nctid: predict for nctid, predict in zip(nctid_all, predict_all)}
        pickle.dump(nctid2predict, open('results/nctid2predict.pkl', 'wb'))
        return nctid_all, predict_all
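
    # A standalone sketch of the bootstrap estimate computed above, assuming
    # parallel lists of predicted probabilities and 0/1 labels (the helper name
    # bootstrap_auc is illustrative, not part of HINT):
    #
    #     from random import choices
    #     def bootstrap_auc(predict_all, label_all, sample_num = 20):
    #         idx = list(range(len(predict_all)))
    #         scores = []
    #         for _ in range(sample_num):
    #             picked = choices(idx, k = len(idx))   # resample with replacement
    #             scores.append(roc_auc_score([label_all[i] for i in picked],
    #                                         [predict_all[i] for i in picked]))
    #         return np.mean(scores), np.std(scores)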
|
|
    def ongoing_test(self, dataloader, sample_num = 20):
        self.eval()
        best_threshold = 0.5
        whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader)
        self.train()
        return nctid_all, predict_all

    def test(self, dataloader, return_loss = True, validloader = None):
        # if validloader is not None:
        #     best_threshold = self.select_threshold_for_binary(validloader)
        self.eval()
        best_threshold = 0.5
        whole_loss, predict_all, label_all, nctid_all = self.generate_predict(dataloader)
        # from HINT.utils import plot_hist
        # plt.clf()
        # prefix_name = "./figure/" + self.save_name
        # plot_hist(prefix_name, predict_all, label_all)
        self.train()
        if return_loss:
            return whole_loss, predict_all, label_all
        else:
            print_num = 6
            auc_score, f1score, prauc_score, precision, recall, accuracy, \
                predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all, threshold = best_threshold)
            print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num]
                  + "\nPR-AUC: " + str(prauc_score)[:print_num]
                  + "\nPrecision: " + str(precision)[:print_num]
                  + "\nrecall: " + str(recall)[:print_num] + "\naccuracy: " + str(accuracy)[:print_num]
                  + "\npredict 1 ratio: " + str(predict_1_ratio)[:print_num]
                  + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num])
            return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio

    def learn(self, train_loader, valid_loader, test_loader):
        opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
        train_loss_record = []
        valid_loss, valid_predict, valid_label = self.test(valid_loader, return_loss = True)
        valid_loss_record = [valid_loss]
        best_valid_loss = valid_loss
        best_model = deepcopy(self)
        train_output = []
        valid_output = []
        for ep in tqdm(range(self.epoch)):
            for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader:
                label_vec = label_vec.to(self.device)
                output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst).view(-1)  #### 32, 1 -> 32, || label_vec 32,
                loss = self.loss(output, label_vec.float())
                train_loss_record.append(loss.item())
                train_output.append((loss.item(), output, label_vec))
                opt.zero_grad()
                loss.backward()
                opt.step()
            valid_loss, valid_predict, valid_label = self.test(valid_loader, return_loss = True)
            valid_loss_record.append(valid_loss)
            valid_output.append((valid_loss, valid_predict, valid_label))

            print(f"valid_loss: {valid_loss}")
            print(best_valid_loss)
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_model = deepcopy(self)

        self.plot_learning_curve(train_loss_record, valid_loss_record)
        self = deepcopy(best_model)
        auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader)
        return train_output, valid_output
|
|
    def plot_learning_curve(self, train_loss_record, valid_loss_record):
        plt.plot(train_loss_record)
        plt.savefig("./figure/" + self.save_name + '_train_loss.jpg')
        plt.clf()
        plt.plot(valid_loss_record)
        plt.savefig("./figure/" + self.save_name + '_valid_loss.jpg')
        plt.clf()

    def select_threshold_for_binary(self, validloader):
        _, prediction, label_all, nctid_all = self.generate_predict(validloader)
        best_f1 = 0
        best_threshold = 0.5  # fallback so a threshold is always returned
        for threshold in prediction:
            float2binary = lambda x: 0 if x < threshold else 1
            predict_all = list(map(float2binary, prediction))
            f1score = f1_score(label_all, predict_all)  # was precision_score; the method selects the threshold by F1
            if f1score > best_f1:
                best_f1 = f1score
                best_threshold = threshold
        return best_threshold
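
# A minimal usage sketch for Interaction (illustrative only; the three encoder
# objects and the dataloaders are assumed to come from the rest of the HINT
# pipeline and are not defined in this file):
#
#     model = Interaction(molecule_encoder, disease_encoder, protocol_encoder,
#                         device = torch.device('cpu'),
#                         global_embed_size = 50,
#                         highway_num_layer = 2,
#                         prefix_name = 'phase_I')
#     model.learn(train_loader, valid_loader, test_loader)
#     model.bootstrap_test(test_loader)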
|
|
class HINTModel_multi(Interaction):

    def __init__(self, molecule_encoder, disease_encoder, protocol_encoder,
                 device,
                 global_embed_size,
                 highway_num_layer,
                 prefix_name,
                 epoch = 20,
                 lr = 3e-4,
                 weight_decay = 0,
                 ):
        super(HINTModel_multi, self).__init__(molecule_encoder = molecule_encoder,
                                              disease_encoder = disease_encoder,
                                              protocol_encoder = protocol_encoder,
                                              device = device,
                                              prefix_name = prefix_name,
                                              global_embed_size = global_embed_size,
                                              highway_num_layer = highway_num_layer,
                                              epoch = epoch,
                                              lr = lr,
                                              weight_decay = weight_decay)
        self.pred_nn = nn.Linear(self.global_embed_size, 4)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
        molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst)
        interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed)
        output = self.pred_nn(interaction_embedding)
        return output  ### 32, 4
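
    # Note: nn.CrossEntropyLoss consumes the raw (batch_size, 4) logits returned
    # here together with integer class indices of shape (batch_size,); no softmax
    # is applied inside forward(). A toy sketch (shapes are illustrative):
    #
    #     logits = torch.randn(32, 4)            # output of pred_nn
    #     labels = torch.randint(0, 4, (32,))    # class indices 0..3
    #     loss = nn.CrossEntropyLoss()(logits, labels)
    #     pred = torch.argmax(logits, 1)         # as in generate_predict below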
|
|
    def generate_predict(self, dataloader):
        whole_loss = 0
        label_all, predict_all = [], []
        for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader:
            label_vec = label_vec.to(self.device)
            output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst)
            loss = self.loss(output, label_vec)
            whole_loss += loss.item()
            predict_all.extend(torch.argmax(output, 1).tolist())
            # predict_all.extend([i.item() for i in torch.sigmoid(output)])
            label_all.extend([i.item() for i in label_vec])

        accuracy = len(list(filter(lambda x: x[0] == x[1], zip(predict_all, label_all)))) / len(label_all)
        return whole_loss, predict_all, label_all, accuracy

    def test(self, dataloader, return_loss = True, validloader = None):
        # if validloader is not None:
        #     best_threshold = self.select_threshold_for_binary(validloader)
        self.eval()
        whole_loss, predict_all, label_all, accuracy = self.generate_predict(dataloader)
        self.train()
        return whole_loss, predict_all, label_all, accuracy
|
|
        # # from HINT.utils import plot_hist
        # # plt.clf()
        # # prefix_name = "./figure/" + self.save_name
        # # plot_hist(prefix_name, predict_all, label_all)
        # self.train()
        # if return_loss:
        #     return whole_loss
        # else:
        #     print_num = 5
        #     auc_score, f1score, prauc_score, precision, recall, accuracy, \
        #         predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all, threshold = best_threshold)
        #     print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num]
        #           + "\nPR-AUC: " + str(prauc_score)[:print_num]
        #           + "\nPrecision: " + str(precision)[:print_num]
        #           + "\nrecall: " + str(recall)[:print_num] + "\naccuracy: " + str(accuracy)[:print_num]
        #           + "\npredict 1 ratio: " + str(predict_1_ratio)[:print_num]
        #           + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num])
        #     return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio

    def learn(self, train_loader, valid_loader, test_loader):
        opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
        train_loss_record = []
        valid_loss, predict_all, label_all, accuracy = self.test(valid_loader, return_loss = True)
        print('accuracy', accuracy)
        # valid_loss_record = [valid_loss]
        # best_valid_loss = valid_loss
        best_model = deepcopy(self)
        for ep in tqdm(range(self.epoch)):
            self.train()
            for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader:
                label_vec = label_vec.to(self.device)
                output = self.forward(smiles_lst2, icdcode_lst3, criteria_lst)  #### logits (32, 4); label_vec holds class indices (32,)
                # print(label_vec.shape, output.shape, label_vec, output)
                loss = self.loss(output, label_vec)
                train_loss_record.append(loss.item())
                opt.zero_grad()
                loss.backward()
                opt.step()
            valid_loss, predict_all, label_all, accuracy = self.test(valid_loader, return_loss = True)
            print('accuracy', accuracy)
        return predict_all, label_all
        # valid_loss_record.append(valid_loss)
        # if valid_loss < best_valid_loss:
        #     best_valid_loss = valid_loss
        #     best_model = deepcopy(self)

        # self.plot_learning_curve(train_loss_record, valid_loss_record)
        # self = deepcopy(best_model)
        # auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader)
|
|
class HINT_nograph(Interaction):
    def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, device,
                 global_embed_size,
                 highway_num_layer,
                 prefix_name,
                 epoch = 20,
                 lr = 3e-4,
                 weight_decay = 0, ):
        super(HINT_nograph, self).__init__(molecule_encoder = molecule_encoder,
                                           disease_encoder = disease_encoder,
                                           protocol_encoder = protocol_encoder,
                                           device = device,
                                           global_embed_size = global_embed_size,
                                           prefix_name = prefix_name,
                                           highway_num_layer = highway_num_layer,
                                           epoch = epoch,
                                           lr = lr,
                                           weight_decay = weight_decay,
                                           )
        self.save_name = prefix_name + '_HINT_nograph'
        ''' ### interaction model
        self.molecule_encoder = molecule_encoder
        self.disease_encoder = disease_encoder
        self.protocol_encoder = protocol_encoder
        self.global_embed_size = global_embed_size
        self.highway_num_layer = highway_num_layer
        self.feature_dim = self.molecule_encoder.embedding_size + self.disease_encoder.embedding_size + self.protocol_encoder.embedding_size
        self.epoch = epoch
        self.lr = lr
        self.weight_decay = weight_decay
        self.save_name = save_name

        self.f = F.relu
        self.loss = nn.BCEWithLogitsLoss()

        ##### NN
        self.encoder2interaction_fc = nn.Linear(self.feature_dim, self.global_embed_size)
        self.encoder2interaction_highway = Highway(self.global_embed_size, self.highway_num_layer)
        self.pred_nn = nn.Linear(self.global_embed_size, 1)
        '''

        #### risk of disease
        self.risk_disease_fc = nn.Linear(self.disease_encoder.embedding_size, self.global_embed_size)
        self.risk_disease_higway = Highway(self.global_embed_size, self.highway_num_layer)

        #### augment interaction
        self.augment_interaction_fc = nn.Linear(self.global_embed_size * 2, self.global_embed_size)
        self.augment_interaction_highway = Highway(self.global_embed_size, self.highway_num_layer)

        #### ADMET
        self.admet_model = []
        for i in range(5):
            admet_fc = nn.Linear(self.molecule_encoder.embedding_size, self.global_embed_size).to(device)
            admet_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
            self.admet_model.append(nn.ModuleList([admet_fc, admet_highway]))
        self.admet_model = nn.ModuleList(self.admet_model)

        #### PK
        self.pk_fc = nn.Linear(self.global_embed_size * 5, self.global_embed_size)
        self.pk_highway = Highway(self.global_embed_size, self.highway_num_layer)

        #### trial node
        self.trial_fc = nn.Linear(self.global_embed_size * 2, self.global_embed_size)
        self.trial_highway = Highway(self.global_embed_size, self.highway_num_layer)

        ## self.pred_nn = nn.Linear(self.global_embed_size, 1)

        self.device = device
        self = self.to(device)
|
|
    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst, if_gnn = False):
        ### encoder for molecule, disease and protocol
        molecule_embed, icd_embed, protocol_embed = self.forward_get_three_encoders(smiles_lst2, icdcode_lst3, criteria_lst)
        ### interaction
        interaction_embedding = self.forward_encoder_2_interaction(molecule_embed, icd_embed, protocol_embed)
        ### risk of disease
        risk_of_disease_embedding = self.feed_lst_of_module(input_feature = icd_embed,
                                                            lst_of_module = [self.risk_disease_fc, self.risk_disease_higway])
        ### augment interaction
        augment_interaction_input = torch.cat([interaction_embedding, risk_of_disease_embedding], 1)
        augment_interaction_embedding = self.feed_lst_of_module(input_feature = augment_interaction_input,
                                                                lst_of_module = [self.augment_interaction_fc, self.augment_interaction_highway])
        ### admet
        admet_embedding_lst = []
        for idx in range(5):
            admet_embedding = self.feed_lst_of_module(input_feature = molecule_embed,
                                                      lst_of_module = self.admet_model[idx])
            admet_embedding_lst.append(admet_embedding)
        ### pk
        pk_input = torch.cat(admet_embedding_lst, 1)
        pk_embedding = self.feed_lst_of_module(input_feature = pk_input,
                                               lst_of_module = [self.pk_fc, self.pk_highway])
        ### trial
        trial_input = torch.cat([pk_embedding, augment_interaction_embedding], 1)
        trial_embedding = self.feed_lst_of_module(input_feature = trial_input,
                                                  lst_of_module = [self.trial_fc, self.trial_highway])
        output = self.pred_nn(trial_embedding)
        if if_gnn == False:
            return output
        else:
            embedding_lst = [molecule_embed, icd_embed, protocol_embed, interaction_embedding, risk_of_disease_embedding,
                             augment_interaction_embedding] + admet_embedding_lst + [pk_embedding, trial_embedding]
            return embedding_lst
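
    # For reference, when if_gnn is True the 13 embeddings are returned in the
    # node order used by HINTModel.generate_adj below:
    #   [molecule, disease, criteria, INTERACTION, risk_disease,
    #    augment_interaction, A, D, M, E, T, PK, final]
    # and HINTModel.forward stacks them into one (13, global_embed_size) node
    # feature matrix per trial.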
|
|
466 |
|
|
|
467 |
|
|
|
468 |
class HINTModel(HINT_nograph): |
|
|
469 |
|
|
|
470 |
def __init__(self, molecule_encoder, disease_encoder, protocol_encoder, |
|
|
471 |
device, |
|
|
472 |
global_embed_size, |
|
|
473 |
highway_num_layer, |
|
|
474 |
prefix_name, |
|
|
475 |
gnn_hidden_size, |
|
|
476 |
epoch = 20, |
|
|
477 |
lr = 3e-4, |
|
|
478 |
weight_decay = 0,): |
|
|
479 |
super(HINTModel, self).__init__(molecule_encoder = molecule_encoder, |
|
|
480 |
disease_encoder = disease_encoder, |
|
|
481 |
protocol_encoder = protocol_encoder, |
|
|
482 |
device = device, |
|
|
483 |
prefix_name = prefix_name, |
|
|
484 |
global_embed_size = global_embed_size, |
|
|
485 |
highway_num_layer = highway_num_layer, |
|
|
486 |
epoch = epoch, |
|
|
487 |
lr = lr, |
|
|
488 |
weight_decay = weight_decay) |
|
|
489 |
self.save_name = prefix_name |
|
|
490 |
self.gnn_hidden_size = gnn_hidden_size |
|
|
491 |
#### GNN |
|
|
492 |
self.adj = self.generate_adj() |
|
|
493 |
self.gnn = GCN( |
|
|
494 |
nfeat = self.global_embed_size, |
|
|
495 |
nhid = self.gnn_hidden_size, |
|
|
496 |
nclass = 1, |
|
|
497 |
dropout = 0.6, |
|
|
498 |
init = 'uniform') |
|
|
499 |
### gnn's attention |
|
|
500 |
self.node_size = self.adj.shape[0] |
|
|
501 |
''' |
|
|
502 |
self.graph_attention_model_mat = nn.ModuleList([nn.ModuleList([self.gnn_attention() \ |
|
|
503 |
if self.adj[i,j]==1 else None \ |
|
|
504 |
for j in range(self.node_size)]) \ |
|
|
505 |
for i in range(self.node_size)]) |
|
|
506 |
''' |
|
|
507 |
self.graph_attention_model_mat = nn.ModuleList([nn.ModuleList([self.gnn_attention() if self.adj[i,j]==1 else None for j in range(self.node_size)]) for i in range(self.node_size)]) |
|
|
508 |
# self.graph_attention_model_mat = nn.ModuleList([nn.ModuleList([self.gnn_attention() if self.adj[i,j]==1 else None for j in range(self.node_size)]) for i in range(self.node_size)]) |
|
|
509 |
|
|
|
510 |
''' |
|
|
511 |
nn.ModuleList([ nn.ModuleList([nn.Linear(3,2) for j in range(5)] + [None]) for i in range(3)]) |
|
|
512 |
''' |
|
|
513 |
|
|
|
514 |
self.device = device |
|
|
515 |
self = self.to(device) |
|
|
516 |
|
|
|
517 |
def generate_adj(self): |
|
|
518 |
##### consistent with HINT_nograph.forward |
|
|
519 |
lst = ["molecule", "disease", "criteria", 'INTERACTION', 'risk_disease', 'augment_interaction', 'A', 'D', 'M', 'E', 'T', 'PK', "final"] |
|
|
520 |
edge_lst = [("disease", "molecule"), ("disease", "criteria"), ("molecule", "criteria"), |
|
|
521 |
("disease", "INTERACTION"), ("molecule", "INTERACTION"), ("criteria", "INTERACTION"), |
|
|
522 |
("disease", "risk_disease"), ('risk_disease', 'augment_interaction'), ('INTERACTION', 'augment_interaction'), |
|
|
523 |
("molecule", "A"), ("molecule", "D"), ("molecule", "M"), ("molecule", "E"), ("molecule", "T"), |
|
|
524 |
('A', 'PK'), ('D', 'PK'), ('M', 'PK'), ('E', 'PK'), ('T', 'PK'), |
|
|
525 |
('augment_interaction', 'final'), ('PK', 'final')] |
|
|
526 |
adj = torch.zeros(len(lst), len(lst)) |
|
|
527 |
adj = torch.eye(len(lst)) * len(lst) |
|
|
528 |
num2str = {k:v for k,v in enumerate(lst)} |
|
|
529 |
str2num = {v:k for k,v in enumerate(lst)} |
|
|
530 |
for i,j in edge_lst: |
|
|
531 |
n1,n2 = str2num[i], str2num[j] |
|
|
532 |
adj[n1,n2] = 1 |
|
|
533 |
adj[n2,n1] = 1 |
|
|
534 |
return adj.to(self.device) |
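
    # The same adjacency recipe on a toy 3-node graph, as a standalone sketch
    # (node names here are made up): self-loop weights come from the scaled
    # identity and every listed edge is inserted symmetrically.
    #
    #     nodes = ["a", "b", "c"]
    #     edges = [("a", "b"), ("b", "c")]
    #     adj = torch.eye(len(nodes)) * len(nodes)
    #     pos = {n: i for i, n in enumerate(nodes)}
    #     for u, v in edges:
    #         adj[pos[u], pos[v]] = 1
    #         adj[pos[v], pos[u]] = 1
    #     # adj -> [[3, 1, 0], [1, 3, 1], [0, 1, 3]]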
|
|
    def generate_attention_matrx(self, node_feature_mat):
        attention_mat = torch.zeros(self.node_size, self.node_size).to(self.device)
        for i in range(self.node_size):
            for j in range(self.node_size):
                if self.adj[i, j] != 1:
                    continue
                feature = torch.cat([node_feature_mat[i].view(1, -1), node_feature_mat[j].view(1, -1)], 1)
                attention_model = self.graph_attention_model_mat[i][j]
                attention_mat[i, j] = torch.sigmoid(self.feed_lst_of_module(input_feature = feature, lst_of_module = attention_model))
        return attention_mat

    ##### self.global_embed_size*2 -> 1
    def gnn_attention(self):
        highway_nn = Highway(size = self.global_embed_size * 2, num_layers = self.highway_num_layer).to(self.device)
        highway_fc = nn.Linear(self.global_embed_size * 2, 1).to(self.device)
        return nn.ModuleList([highway_nn, highway_fc])

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst, return_attention_matrix = False):
        embedding_lst = HINT_nograph.forward(self, smiles_lst2, icdcode_lst3, criteria_lst, if_gnn = True)
        ### length is 13, each is 32, 50
        batch_size = embedding_lst[0].shape[0]
        output_lst = []
        if return_attention_matrix:
            attention_mat_lst = []
        for i in range(batch_size):
            node_feature_lst = [embedding[i].view(1, -1) for embedding in embedding_lst]
            node_feature_mat = torch.cat(node_feature_lst, 0)  ### 13, 50
            attention_mat = self.generate_attention_matrx(node_feature_mat)
            output = self.gnn(node_feature_mat, self.adj * attention_mat)
            output = output[-1].view(1, -1)
            output_lst.append(output)
            if return_attention_matrix:
                attention_mat_lst.append(attention_mat)
        output_mat = torch.cat(output_lst, 0)
        if not return_attention_matrix:
            return output_mat
        else:
            return output_mat, attention_mat_lst
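
    # Per-sample flow in forward() above: each trial's 13 node embeddings form a
    # (13, global_embed_size) matrix, attention weights in (0, 1) are produced
    # for existing edges only, and the GCN runs on adj * attention so absent
    # edges stay zero. A usage sketch (model construction and the dataloader
    # fields are assumed from the HINT pipeline):
    #
    #     output_mat, attention_mat_lst = model.forward(smiles_lst2, icdcode_lst3,
    #                                                   criteria_lst,
    #                                                   return_attention_matrix = True)
    #     probs = torch.sigmoid(output_mat.view(-1))   # one probability per trial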
|
|
    def interpret(self, complete_dataloader):
        from graph_visualize_interpret import data2graph
        from HINT.utils import replace_strange_symbol
        for nctid_lst, status_lst, why_stop_lst, label_vec, phase_lst, \
                diseases_lst, icdcode_lst3, drugs_lst, smiles_lst2, criteria_lst in complete_dataloader:
            output, attention_mat_lst = self.forward(smiles_lst2, icdcode_lst3, criteria_lst, return_attention_matrix = True)
            output = output.view(-1)
            batch_size = len(nctid_lst)
            for i in range(batch_size):
                name = '__'.join([nctid_lst[i], status_lst[i], why_stop_lst[i],
                                  str(label_vec[i].item()), str(torch.sigmoid(output[i]).item())[:5],
                                  phase_lst[i], diseases_lst[i], drugs_lst[i]])
                if len(name) > 150:
                    name = name[:250]
                name = replace_strange_symbol(name)
                name = name.replace('__', '_')
                name = name.replace('  ', ' ')  # collapse double spaces in the filename
                name = 'interpret_result/' + name + '.png'
                print(name)
                data2graph(attention_matrix = attention_mat_lst[i], adj = self.adj, save_name = name)

    def init_pretrain(self, admet_model):
        self.molecule_encoder = admet_model.molecule_encoder
|
|
    ### generate attention matrix


class Only_Molecule(Interaction):

    def __init__(self, molecule_encoder, disease_encoder, protocol_encoder,
                 global_embed_size,
                 highway_num_layer,
                 prefix_name,
                 epoch = 20,
                 lr = 3e-4,
                 weight_decay = 0):
        # NOTE: Interaction.__init__ also expects a `device` argument, which is
        # not forwarded here.
        super(Only_Molecule, self).__init__(molecule_encoder = molecule_encoder,
                                            disease_encoder = disease_encoder,
                                            protocol_encoder = protocol_encoder,
                                            global_embed_size = global_embed_size,
                                            highway_num_layer = highway_num_layer,
                                            prefix_name = prefix_name,
                                            epoch = epoch,
                                            lr = lr,
                                            weight_decay = weight_decay,)
        self.molecule2out = nn.Linear(self.global_embed_size, 1)

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
        molecule_embed = self.molecule_encoder.forward_smiles_lst_lst(smiles_lst2)
        return self.molecule2out(molecule_embed)
|
|
class Only_Disease(Only_Molecule):

    def __init__(self, molecule_encoder, disease_encoder, protocol_encoder,
                 global_embed_size,
                 highway_num_layer,
                 prefix_name,
                 epoch = 20,
                 lr = 3e-4,
                 weight_decay = 0):
        super(Only_Disease, self).__init__(molecule_encoder = molecule_encoder,
                                           disease_encoder = disease_encoder,
                                           protocol_encoder = protocol_encoder,
                                           global_embed_size = global_embed_size,
                                           highway_num_layer = highway_num_layer,
                                           prefix_name = prefix_name,
                                           epoch = epoch,
                                           lr = lr,
                                           weight_decay = weight_decay,)
        self.disease2out = self.molecule2out

    def forward(self, smiles_lst2, icdcode_lst3, criteria_lst):
        icd_embed = self.disease_encoder.forward_code_lst3(icdcode_lst3)
        return self.disease2out(icd_embed)
|
|
def dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, global_icd):
    ## label_vec: (n,)
    y = label_vec

    num_icd = len(global_icd)
    from HINT.utils import smiles_lst2fp
    fp_lst = [smiles_lst2fp(smiles_lst).reshape(1, -1) for smiles_lst in smiles_lst2]
    fp_mat = np.concatenate(fp_lst, 0)
    # fp_mat = torch.from_numpy(fp_mat)  ### (n, 2048)

    icdcode_lst = []
    for lst2 in icdcode_lst3:
        lst = list(reduce(lambda x, y: x + y, lst2))
        lst = [i.split('.')[0] for i in lst]
        lst = set(lst)
        icd_feature = np.zeros((1, num_icd), np.int32)
        for ele in lst:
            if ele in global_icd:
                idx = global_icd.index(ele)
                icd_feature[0, idx] = 1
        icdcode_lst.append(icd_feature)
    icdcode_mat = np.concatenate(icdcode_lst, 0)
    X = np.concatenate([fp_mat, icdcode_mat], 1)
    X = torch.from_numpy(X)
    X = X.float()
    # icdcode_mat = torch.from_numpy(icdcode_mat)

    # X = torch.cat([fp_mat, icdcode_mat], 1)
    return X, y
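
# A minimal sketch of feeding dataloader2Xy into a simple baseline classifier
# (logistic regression here is purely illustrative; HINT itself uses the FFNN
# class below, and `train_loader` / `global_icd` are assumed from the pipeline):
#
#     from sklearn.linear_model import LogisticRegression
#     X_lst, y_lst = [], []
#     for batch in train_loader:
#         X, y = dataloader2Xy(*batch, global_icd)
#         X_lst.append(X.numpy())
#         y_lst.append(np.asarray(y))
#     clf = LogisticRegression(max_iter = 1000)
#     clf.fit(np.concatenate(X_lst, 0), np.concatenate(y_lst, 0))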
|
|
class FFNN(nn.Sequential):
    def __init__(self, molecule_dim, diseasecode_dim,
                 global_icd,
                 protocol_dim = 0,
                 prefix_name = 'FFNN',
                 epoch = 10,
                 lr = 3e-4,
                 weight_decay = 0,
                 ):
        super(FFNN, self).__init__()
        self.molecule_dim = molecule_dim
        self.diseasecode_dim = diseasecode_dim
        self.protocol_dim = protocol_dim
        self.prefix_name = prefix_name
        self.epoch = epoch
        self.lr = lr
        self.weight_decay = weight_decay
        self.global_icd = global_icd
        self.num_icd = len(global_icd)

        self.fc_dims = [self.molecule_dim + self.diseasecode_dim + self.protocol_dim, 2000, 1000, 200, 50, 1]
        self.fc_layers = nn.ModuleList([nn.Linear(v, self.fc_dims[i + 1]) for i, v in enumerate(self.fc_dims[:-1])])
        self.loss = nn.BCEWithLogitsLoss()
        self.save_name = prefix_name

    def forward(self, X):
        for i in range(len(self.fc_layers) - 1):
            fc_layer = self.fc_layers[i]
            X = fc_layer(X)
        last_layer = self.fc_layers[-1]
        # Return raw logits: self.loss is BCEWithLogitsLoss, and generate_predict
        # applies torch.sigmoid itself. (The earlier F.sigmoid here applied the
        # sigmoid twice and fed probabilities to a logits loss.)
        pred = last_layer(X)
        return pred
|
|
    def learn(self, train_loader, valid_loader, test_loader):
        opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
        train_loss_record = []
        valid_loss = self.test(valid_loader, return_loss = True)
        valid_loss_record = [valid_loss]
        best_valid_loss = valid_loss
        best_model = deepcopy(self)

        for ep in tqdm(range(self.epoch)):
            for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in train_loader:
                X, _ = dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, self.global_icd)
                output = self.forward(X).view(-1)  #### 32, 1 -> 32, || label_vec 32,
                loss = self.loss(output, label_vec.float())
                train_loss_record.append(loss.item())
                opt.zero_grad()
                loss.backward()
                opt.step()
            valid_loss = self.test(valid_loader, return_loss = True)
            valid_loss_record.append(valid_loss)
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_model = deepcopy(self)

        self.plot_learning_curve(train_loss_record, valid_loss_record)
        self = deepcopy(best_model)
        auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio = self.test(test_loader, return_loss = False, validloader = valid_loader)
|
|
    def evaluation(self, predict_all, label_all, threshold = 0.5):
        import pickle, os
        from sklearn.metrics import roc_curve, precision_recall_curve
        with open("predict_label.txt", 'w') as fout:
            for i, j in zip(predict_all, label_all):
                fout.write(str(i)[:4] + '\t' + str(j)[:4] + '\n')
        auc_score = roc_auc_score(label_all, predict_all)
        figure_folder = "figure"
        #### ROC-curve
        fpr, tpr, thresholds = roc_curve(label_all, predict_all, pos_label=1)
        # roc_curve = plt.figure()
        # plt.plot(fpr, tpr, '-', label = self.save_name + ' ROC Curve ')
        # plt.legend(fontsize = 15)
        # plt.savefig(os.path.join(figure_folder, name + "_roc_curve.png"))
        #### PR-curve
        precision, recall, thresholds = precision_recall_curve(label_all, predict_all)
        # plt.plot(recall, precision, label = self.save_name + ' PR Curve')
        # plt.legend(fontsize = 15)
        # plt.savefig(os.path.join(figure_folder, self.save_name + "_pr_curve.png"))
        label_all = [int(i) for i in label_all]
        float2binary = lambda x: 0 if x < threshold else 1
        predict_all = list(map(float2binary, predict_all))
        f1score = f1_score(label_all, predict_all)
        prauc_score = average_precision_score(label_all, predict_all)
        # print(predict_all)
        precision = precision_score(label_all, predict_all)
        recall = recall_score(label_all, predict_all)
        accuracy = accuracy_score(label_all, predict_all)
        predict_1_ratio = sum(predict_all) / len(predict_all)
        label_1_ratio = sum(label_all) / len(label_all)
        return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio

    def generate_predict(self, dataloader):
        whole_loss = 0
        label_all, predict_all = [], []
        for nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst in dataloader:
            X, _ = dataloader2Xy(nctid_lst, label_vec, smiles_lst2, icdcode_lst3, criteria_lst, self.global_icd)
            output = self.forward(X).view(-1)
            loss = self.loss(output, label_vec.float())
            whole_loss += loss.item()
            predict_all.extend([i.item() for i in torch.sigmoid(output)])
            label_all.extend([i.item() for i in label_vec])

        return whole_loss, predict_all, label_all
|
|
    def bootstrap_test(self, dataloader, validloader = None, sample_num = 20):
        best_threshold = 0.5
        # if validloader is not None:
        #     best_threshold = self.select_threshold_for_binary(validloader)
        self.eval()
        whole_loss, predict_all, label_all = self.generate_predict(dataloader)
        from HINT.utils import plot_hist
        plt.clf()
        prefix_name = "./figure/" + self.save_name
        plot_hist(prefix_name, predict_all, label_all)

        def bootstrap(length, sample_num):
            idx = [i for i in range(length)]
            from random import choices
            bootstrap_idx = [choices(idx, k = length) for i in range(sample_num)]
            return bootstrap_idx

        results_lst = []
        bootstrap_idx_lst = bootstrap(len(predict_all), sample_num = sample_num)
        for bootstrap_idx in bootstrap_idx_lst:
            bootstrap_label = [label_all[idx] for idx in bootstrap_idx]
            bootstrap_predict = [predict_all[idx] for idx in bootstrap_idx]
            results = self.evaluation(bootstrap_predict, bootstrap_label, threshold = best_threshold)
            results_lst.append(results)
        self.train()
        auc = [results[0] for results in results_lst]
        f1score = [results[1] for results in results_lst]
        prauc_score = [results[2] for results in results_lst]
        print("PR-AUC mean: " + str(np.mean(prauc_score))[:6], "std: " + str(np.std(prauc_score))[:6])
        print("F1 mean: " + str(np.mean(f1score))[:6], "std: " + str(np.std(f1score))[:6])
        print("ROC-AUC mean: " + str(np.mean(auc))[:6], "std: " + str(np.std(auc))[:6])

    def test(self, dataloader, return_loss = True, validloader = None):
        # if validloader is not None:
        #     best_threshold = self.select_threshold_for_binary(validloader)
        self.eval()
        best_threshold = 0.5
        whole_loss, predict_all, label_all = self.generate_predict(dataloader)
        # from HINT.utils import plot_hist
        # plt.clf()
        # prefix_name = "./figure/" + self.save_name
        # plot_hist(prefix_name, predict_all, label_all)
        self.train()
        if return_loss:
            return whole_loss
        else:
            print_num = 5
            auc_score, f1score, prauc_score, precision, recall, accuracy, \
                predict_1_ratio, label_1_ratio = self.evaluation(predict_all, label_all, threshold = best_threshold)
            print("ROC AUC: " + str(auc_score)[:print_num] + "\nF1: " + str(f1score)[:print_num]
                  + "\nPR-AUC: " + str(prauc_score)[:print_num]
                  + "\nPrecision: " + str(precision)[:print_num]
                  + "\nrecall: " + str(recall)[:print_num] + "\naccuracy: " + str(accuracy)[:print_num]
                  + "\npredict 1 ratio: " + str(predict_1_ratio)[:print_num]
                  + "\nlabel 1 ratio: " + str(label_1_ratio)[:print_num])
            return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio

    def plot_learning_curve(self, train_loss_record, valid_loss_record):
        plt.plot(train_loss_record)
        plt.savefig("./figure/" + self.save_name + '_train_loss.jpg')
        plt.clf()
        plt.plot(valid_loss_record)
        plt.savefig("./figure/" + self.save_name + '_valid_loss.jpg')
        plt.clf()
|
|
class ADMET(nn.Sequential):
    def __init__(self, mpnn_model, device):
        super(ADMET, self).__init__()
        self.num = 5
        self.mpnn_model = mpnn_model
        self.device = device
        self.mpnn_dim = mpnn_model.mpnn_hidden_size
        self.admet_model = []
        self.global_embed_size = self.mpnn_dim
        self.highway_num_layer = 2
        for i in range(5):
            admet_fc = nn.Linear(self.mpnn_model.mpnn_hidden_size, self.global_embed_size).to(device)
            admet_highway = Highway(self.global_embed_size, self.highway_num_layer).to(device)
            self.admet_model.append(nn.ModuleList([admet_fc, admet_highway]))
        self.admet_model = nn.ModuleList(self.admet_model)

        self.admet_pred = nn.ModuleList([nn.Linear(self.global_embed_size, 1).to(device) for i in range(5)])
        self.f = F.relu

        self.device = device
        self = self.to(device)

    def feed_lst_of_module(self, input_feature, lst_of_module):
        x = input_feature
        for single_module in lst_of_module:
            x = self.f(single_module(x))
        return x

    def forward(self, smiles_lst, idx):
        assert idx in list(range(5))
        '''
        Encode the SMILES lists with the shared MPNN encoder, then apply the
        idx-th property-specific fc + highway block and its prediction head.
        '''
        embeds = self.mpnn_model.forward_smiles_lst_lst(smiles_lst)
        embeds = self.feed_lst_of_module(embeds, self.admet_model[idx])
        output = self.admet_pred[idx](embeds)
        return output
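
    # A minimal usage sketch for a single ADMET head (the mpnn_model object and
    # the nesting of smiles_lst follow the HINT molecule encoder conventions and
    # are assumed here; idx selects one of the five property-specific heads):
    #
    #     admet = ADMET(mpnn_model, device)
    #     logits = admet.forward(smiles_lst, idx = 0)   # (batch_size, 1)
    #     probs = torch.sigmoid(logits)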
|
|
    def test(self, valid_loader):
        pass

    def learn(self, train_loader, valid_loader, idx):
        # NOTE: this method is unfinished as written: it relies on self.lr,
        # self.weight_decay, self.epoch and self.loss, which are never set in
        # ADMET.__init__; it calls self.forward without the required `idx`
        # argument; `label_vec` is never unpacked from the loader; and it calls
        # self.test with a `return_loss` keyword the stub above does not accept.
        opt = torch.optim.Adam(self.parameters(), lr = self.lr, weight_decay = self.weight_decay)
        train_loss_record = []
        valid_loss = self.test(valid_loader, return_loss = True)
        valid_loss_record = [valid_loss]
        best_valid_loss = valid_loss
        best_model = deepcopy(self)

        for ep in tqdm(range(self.epoch)):
            for smiles_lst in train_loader:
                output = self.forward(smiles_lst).view(-1)  #### 32, 1 -> 32, || label_vec 32,
                loss = self.loss(output, label_vec.float())
                train_loss_record.append(loss.item())
                opt.zero_grad()
                loss.backward()
                opt.step()
            valid_loss = self.test(valid_loader, return_loss = True)
            valid_loss_record.append(valid_loss)
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_model = deepcopy(self)

        self = deepcopy(best_model)
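
    # learn() above is left incomplete (see the note inside it). A minimal sketch
    # of how an ADMET pre-training step could look, under assumptions that are not
    # part of this file: the loader yields (smiles_lst, label_vec) pairs, a task
    # index has been chosen, and the optimizer settings are supplied by the caller.
    #
    #     opt = torch.optim.Adam(admet.parameters(), lr = 3e-4)
    #     criterion = nn.BCEWithLogitsLoss()
    #     for smiles_lst, label_vec in train_loader:
    #         logits = admet.forward(smiles_lst, idx = 0).view(-1)
    #         loss = criterion(logits, label_vec.float())
    #         opt.zero_grad()
    #         loss.backward()
    #         opt.step()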