|
a |
|
b/HINT/dataloader.py |
|
|
1 |
''' |
|
|
2 |
|
|
|
3 |
(I). Trial_Dataset for prediction |
|
|
4 |
(II). Trial_Dataset_Complete for interpretation |
|
|
5 |
(III). SMILES lst |
|
|
6 |
(IV). disease lst icd-code |
|
|
7 |
|
|
|
8 |
''' |
|
|
9 |
|
|
|
10 |
import torch, csv, os |
|
|
11 |
from torch.utils import data |
|
|
12 |
from torch.utils.data.dataloader import default_collate |
|
|
13 |
from HINT.molecule_encode import smiles2mpnnfeature |
|
|
14 |
from HINT.protocol_encode import protocol2feature, load_sentence_2_vec |
|
|
15 |
|
|
|
16 |
# Loaded once at import time; the trial collate functions in this module pass
# it to protocol2feature together with each trial's criteria text.
sentence2vec = load_sentence_2_vec()
|
|
17 |
|
|
|
18 |
|
|
|
19 |
class Trial_Dataset(data.Dataset):
    """Map-style dataset over five parallel per-trial lists (prediction use).

    Indexing yields the tuple ``(nctid, label, smiles, icdcode, criteria)``
    for one trial. Values are returned exactly as they were stored; no
    parsing or conversion happens here.
    """

    def __init__(self, nctid_lst, label_lst, smiles_lst, icdcode_lst, criteria_lst):
        """Store the five parallel lists; position ``k`` of each list describes trial ``k``."""
        self.criteria_lst = criteria_lst
        self.icdcode_lst = icdcode_lst
        self.smiles_lst = smiles_lst
        self.label_lst = label_lst
        self.nctid_lst = nctid_lst

    def __len__(self):
        """Return the number of trials, taken from the trial-id list."""
        n_trials = len(self.nctid_lst)
        return n_trials

    def __getitem__(self, index):
        """Return ``(nctid, label, smiles, icdcode, criteria)`` for trial ``index``."""
        #### smiles_lst[index] is list of smiles
        columns = (
            self.nctid_lst,
            self.label_lst,
            self.smiles_lst,
            self.icdcode_lst,
            self.criteria_lst,
        )
        return tuple(column[index] for column in columns)
|
|
33 |
|
|
|
34 |
|
|
|
35 |
class Trial_Dataset_Complete(Trial_Dataset):
    """Trial dataset that also carries the descriptive columns (interpretation use).

    Extends ``Trial_Dataset`` with status, why-stopped, phase, disease and
    drug lists. Indexing yields a 10-tuple in the order
    ``(nctid, status, why_stop, label, phase, diseases, icdcode, drugs,
    smiles, criteria)``.
    """

    def __init__(self, nctid_lst, status_lst, why_stop_lst, label_lst, phase_lst,
                 diseases_lst, icdcode_lst, drugs_lst, smiles_lst, criteria_lst):
        """Store all ten parallel lists; the five shared ones are set by the parent class."""
        super().__init__(nctid_lst, label_lst, smiles_lst, icdcode_lst, criteria_lst)
        self.drugs_lst = drugs_lst
        self.diseases_lst = diseases_lst
        self.phase_lst = phase_lst
        self.why_stop_lst = why_stop_lst
        self.status_lst = status_lst

    def __getitem__(self, index):
        """Return the 10-field record for trial ``index`` as a tuple."""
        columns = (
            self.nctid_lst,
            self.status_lst,
            self.why_stop_lst,
            self.label_lst,
            self.phase_lst,
            self.diseases_lst,
            self.icdcode_lst,
            self.drugs_lst,
            self.smiles_lst,
            self.criteria_lst,
        )
        return tuple(column[index] for column in columns)
|
|
48 |
|
|
|
49 |
|
|
|
50 |
class ADMET_Dataset(data.Dataset):
    """Map-style dataset pairing each SMILES string with its label."""

    def __init__(self, smiles_lst, label_lst):
        """Keep the two parallel lists; entry ``k`` of each belongs to sample ``k``."""
        self.label_lst = label_lst
        self.smiles_lst = smiles_lst

    def __len__(self):
        """Return the number of samples, taken from the SMILES list."""
        n_samples = len(self.smiles_lst)
        return n_samples

    def __getitem__(self, index):
        """Return ``(smiles, label)`` for sample ``index``."""
        smiles = self.smiles_lst[index]
        label = self.label_lst[index]
        return smiles, label
|
|
60 |
|
|
|
61 |
def admet_collate_fn(x):
    """Collate a batch of ``(smiles, label)`` samples.

    Args:
        x: sequence of samples; element 0 of each is the SMILES, element 1
            is a label convertible with ``int()``.

    Returns:
        ``[smiles_lst, label_vec]`` where ``smiles_lst`` is a plain list of
        the SMILES entries and ``label_vec`` is the result of
        ``default_collate`` on the integer labels (shape n,).
    """
    smiles_lst = []
    for sample in x:
        smiles_lst.append(sample[0])
    int_labels = []
    for sample in x:
        int_labels.append(int(sample[1]))
    label_vec = default_collate(int_labels)  ### shape n,
    return [smiles_lst, label_vec]
|
|
65 |
|
|
|
66 |
def smiles_txt_to_lst(text):
    """Parse the text form of a list of quoted SMILES into a list of strings.

    Example input:
    "['CN[C@H]1CC[C@@H](C2=CC(Cl)=C(Cl)C=C2)C2=CC=CC=C12', 'CNCCC=C1C2=CC=CC=C2CCC2=CC=CC=C12']"

    The first and last characters (the brackets) are dropped, the rest is
    split on commas, and each piece is stripped of surrounding whitespace
    and then of its first and last character (the quotes).

    Note: the text "[]" yields [''] (one empty string), not an empty list.
    """
    inner = text[1:-1]
    smiles = []
    for piece in inner.split(','):
        quoted = piece.strip()
        smiles.append(quoted[1:-1])
    return smiles
|
|
73 |
|
|
|
74 |
def icdcode_text_2_lst_of_lst(text):
    """Parse the text form of a list of quoted code-lists into a list of lists.

    Example input:
    ["['C34.1', 'C34.2']", "['E11']"]  ->  [['C34.1', 'C34.2'], ['E11']]

    The outer two characters on each side are dropped, the remainder is
    split on the separator '", "', and every group then has its brackets
    removed and is split on commas; each code is stripped of whitespace and
    of its first and last character (the quotes).

    Note: the text "[]" yields [['']], not an empty list.
    """
    def parse_group(group_text):
        # Drop the group's own brackets, then unquote each comma-separated code.
        unbracketed = group_text[1:-1]
        return [code.strip()[1:-1] for code in unbracketed.split(',')]

    body = text[2:-2]
    return [parse_group(group) for group in body.split('", "')]
|
|
81 |
|
|
|
82 |
def trial_collate_fn(x):
    """Collate a batch of ``(nctid, label, smiles, icdcode, criteria)`` samples.

    Args:
        x: sequence of samples indexed positionally as above.

    Returns:
        ``[nctid_lst, label_vec, smiles_lst, icdcode_lst, criteria_lst]``:
        ids are passed through, labels are converted with ``int()`` and
        collated by ``default_collate`` (shape n,), SMILES and ICD-code
        texts are parsed with ``smiles_txt_to_lst`` and
        ``icdcode_text_2_lst_of_lst``, and each criteria entry is passed to
        ``protocol2feature`` together with the module-level ``sentence2vec``.
    """
    def column(position):
        # Gather one field across the whole batch.
        return [sample[position] for sample in x]

    nctid_lst = column(0)  ### ['NCT00604461', ..., 'NCT00788957']
    label_vec = default_collate([int(raw_label) for raw_label in column(1)])  ### shape n,
    smiles_lst = [smiles_txt_to_lst(smiles_text) for smiles_text in column(2)]
    icdcode_lst = [icdcode_text_2_lst_of_lst(icd_text) for icd_text in column(3)]
    criteria_lst = [protocol2feature(criteria, sentence2vec) for criteria in column(4)]
    return [nctid_lst, label_vec, smiles_lst, icdcode_lst, criteria_lst]
|
|
89 |
|
|
|
90 |
def trial_complete_collate_fn(x):
    """Collate a batch of 10-field trial samples.

    Each sample is indexed positionally as ``(nctid, status, why_stop,
    label, phase, diseases, icdcode, drugs, smiles, criteria)``.

    Returns:
        A list with the ten columns in that same order. Labels are converted
        with ``int()`` and collated by ``default_collate`` (shape n,);
        ICD-code and SMILES texts are parsed with
        ``icdcode_text_2_lst_of_lst`` and ``smiles_txt_to_lst``; each
        criteria entry is passed to ``protocol2feature`` together with the
        module-level ``sentence2vec``; all other columns are passed through
        unchanged.
    """
    def column(position):
        # Gather one field across the whole batch.
        return [sample[position] for sample in x]

    nctid_lst = column(0)  ### ['NCT00604461', ..., 'NCT00788957']
    status_lst = column(1)
    why_stop_lst = column(2)
    label_vec = default_collate([int(raw_label) for raw_label in column(3)])  ### shape n,
    phase_lst = column(4)
    diseases_lst = column(5)
    icdcode_lst = [icdcode_text_2_lst_of_lst(icd_text) for icd_text in column(6)]
    drugs_lst = column(7)
    smiles_lst = [smiles_txt_to_lst(smiles_text) for smiles_text in column(8)]
    criteria_lst = [protocol2feature(criteria, sentence2vec) for criteria in column(9)]
    return [
        nctid_lst,
        status_lst,
        why_stop_lst,
        label_vec,
        phase_lst,
        diseases_lst,
        icdcode_lst,
        drugs_lst,
        smiles_lst,
        criteria_lst,
    ]
|
|
102 |
|
|
|
103 |
def csv_three_feature_2_dataloader(csvfile, shuffle, batch_size):
    """Build a DataLoader over the prediction columns of a trial CSV file.

    The file is expected to have a header row followed by rows with columns
    nctid,status,why_stop,label,phase,diseases,icdcodes,drugs,smiless,criteria.
    Only nctid (0), label (3), icdcodes (6), smiless (8) and criteria (9)
    are used; values are kept as raw strings and parsed later by
    ``trial_collate_fn``.

    Args:
        csvfile: path of the CSV file to read.
        shuffle: forwarded to ``DataLoader(shuffle=...)``.
        batch_size: forwarded to ``DataLoader(batch_size=...)``.

    Returns:
        A ``DataLoader`` over a ``Trial_Dataset`` using ``trial_collate_fn``.
    """
    # A distinct handle name keeps the path argument from being shadowed.
    with open(csvfile, 'r') as fin:
        reader = csv.reader(fin, delimiter=',')
        next(reader, None)  # skip the header row (no-op on an empty file)
        rows = list(reader)
    ## nctid,status,why_stop,label,phase,diseases,icdcodes,drugs,smiless,criteria
    nctid_lst = [row[0] for row in rows]
    label_lst = [row[3] for row in rows]
    icdcode_lst = [row[6] for row in rows]
    smiles_lst = [row[8] for row in rows]
    criteria_lst = [row[9] for row in rows]
    dataset = Trial_Dataset(nctid_lst, label_lst, smiles_lst, icdcode_lst, criteria_lst)
    data_loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=trial_collate_fn)
    return data_loader
|
|
116 |
|
|
|
117 |
def csv_three_feature_2_complete_dataloader(csvfile, shuffle, batch_size):
    """Build a DataLoader over all ten columns of a trial CSV file.

    The file is expected to have a header row followed by rows whose first
    ten columns are read positionally as nctid, status, why_stop, label,
    phase, diseases, icdcodes, drugs, smiles and criteria. Values are kept
    as raw strings and parsed later by ``trial_complete_collate_fn``.

    Args:
        csvfile: path of the CSV file to read.
        shuffle: forwarded to ``DataLoader(shuffle=...)``.
        batch_size: forwarded to ``DataLoader(batch_size=...)``.

    Returns:
        A ``DataLoader`` over a ``Trial_Dataset_Complete`` using
        ``trial_complete_collate_fn``.
    """
    # A distinct handle name keeps the path argument from being shadowed.
    with open(csvfile, 'r') as fin:
        reader = csv.reader(fin, delimiter=',')
        next(reader, None)  # skip the header row (no-op on an empty file)
        rows = list(reader)
    nctid_lst = [row[0] for row in rows]
    status_lst = [row[1] for row in rows]
    why_stop_lst = [row[2] for row in rows]
    label_lst = [row[3] for row in rows]
    phase_lst = [row[4] for row in rows]
    diseases_lst = [row[5] for row in rows]
    icdcode_lst = [row[6] for row in rows]
    drugs_lst = [row[7] for row in rows]
    smiles_lst = [row[8] for row in rows]
    criteria_lst = [row[9] for row in rows]
    dataset = Trial_Dataset_Complete(nctid_lst, status_lst, why_stop_lst, label_lst, phase_lst,
                                     diseases_lst, icdcode_lst, drugs_lst, smiles_lst, criteria_lst)
    data_loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=trial_complete_collate_fn)
    return data_loader
|
|
135 |
|
|
|
136 |
def smiles_txt_to_2lst(smiles_txt_file):
    """Read a whitespace-separated ``<smiles> <label>`` text file.

    Each non-blank line contributes its first field to the SMILES list and
    its second field, converted with ``int()``, to the label list. Lines
    that are empty or contain only whitespace (for example a trailing blank
    line) are skipped instead of raising ``IndexError``.

    Args:
        smiles_txt_file: path of the text file to read.

    Returns:
        ``(smiles_lst, label_lst)`` — two parallel lists.

    Raises:
        IndexError: if a non-blank line has fewer than two fields.
        ValueError: if a label field is not an integer literal.
    """
    smiles_lst, label_lst = [], []
    with open(smiles_txt_file, 'r') as fin:
        for line in fin:
            fields = line.split()  # split once and reuse for both columns
            if not fields:
                continue  # blank / whitespace-only line
            smiles_lst.append(fields[0])
            label_lst.append(int(fields[1]))
    return smiles_lst, label_lst
|
|
142 |
|
|
|
143 |
def generate_admet_dataloader_lst(batch_size, datafolder="data/ADMET/cooked/"):
    """Build one (train, validation) DataLoader pair per ADMET property.

    For each of absorption, distribution, metabolism, excretion and toxicity
    the files ``<name>_train.txt`` and ``<name>_valid.txt`` inside
    ``datafolder`` are read with ``smiles_txt_to_2lst`` and wrapped in
    ``ADMET_Dataset`` objects. No ``collate_fn`` is passed, so the loaders
    use DataLoader's default collation.

    Args:
        batch_size: batch size used for every loader.
        datafolder: directory holding the ``*_train.txt`` / ``*_valid.txt``
            files; defaults to the previously hard-coded location.

    Returns:
        List of five ``(train_dataloader, valid_dataloader)`` tuples in the
        property order listed above. The training loader shuffles; the
        validation loader does not.
    """
    name_lst = ["absorption", 'distribution', 'metabolism', 'excretion', 'toxicity']
    dataloader_lst = []
    for name in name_lst:
        train_file = os.path.join(datafolder, name + '_train.txt')
        valid_file = os.path.join(datafolder, name + '_valid.txt')
        train_smiles_lst, train_label_lst = smiles_txt_to_2lst(train_file)
        valid_smiles_lst, valid_label_lst = smiles_txt_to_2lst(valid_file)
        train_dataset = ADMET_Dataset(smiles_lst=train_smiles_lst, label_lst=train_label_lst)
        valid_dataset = ADMET_Dataset(smiles_lst=valid_smiles_lst, label_lst=valid_label_lst)
        train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        valid_dataloader = data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
        dataloader_lst.append((train_dataloader, valid_dataloader))
    return dataloader_lst
|
|
158 |
|
|
|
159 |
# ## x is a list, len(x)=batch_size, x[i] is tuple, len(x[0])=5 |
|
|
160 |
# def mpnn_feature_collate_func(x): |
|
|
161 |
# return [torch.cat([x[j][i] for j in range(len(x))], 0) for i in range(len(x[0]))] |
|
|
162 |
|
|
|
163 |
|
|
|
164 |
# def mpnn_collate_func(x): |
|
|
165 |
# #print("len(x) is ", len(x)) ## batch_size |
|
|
166 |
# #print("len(x[0]) is ", len(x[0])) ## 3--- data_process_loader.__getitem__ |
|
|
167 |
# mpnn_feature = [i[0] for i in x] |
|
|
168 |
# #print("len(mpnn_feature)", len(mpnn_feature), "len(mpnn_feature[0])", len(mpnn_feature[0])) |
|
|
169 |
# mpnn_feature = mpnn_feature_collate_func(mpnn_feature) |
|
|
170 |
# from torch.utils.data.dataloader import default_collate |
|
|
171 |
# x_remain = [i[1:] for i in x] |
|
|
172 |
# x_remain_collated = default_collate(x_remain) |
|
|
173 |
# return [mpnn_feature] + x_remain_collated |
|
|
174 |
|
|
|
175 |
|
|
|
176 |
|
|
|
177 |
|
|
|
178 |
|
|
|
179 |
|
|
|
180 |
|
|
|
181 |
|
|
|
182 |
|
|
|
183 |
|
|
|
184 |
|
|
|
185 |
|
|
|
186 |
|
|
|
187 |
|
|
|
188 |
|
|
|
189 |
|
|
|
190 |
|
|
|
191 |
|
|
|
192 |
|
|
|
193 |
|
|
|
194 |
|
|
|
195 |
|