Diff of /benchmark/dataloader.py [000000] .. [bc9e98]

Switch to unified view

a b/benchmark/dataloader.py
1
'''
2
3
(I). Trial_Dataset for prediction
4
(II). Trial_Dataset_Complete for interpretation
5
(III). SMILES lst 
6
(IV). disease lst icd-code 
7
8
'''
9
10
import torch, csv, os
11
from torch.utils import data 
12
from torch.utils.data.dataloader import default_collate
13
from molecule_encode import smiles2mpnnfeature
14
from protocol_encode import protocol2feature, load_sentence_2_vec
15
16
sentence2vec = load_sentence_2_vec() 
17
18
class Trial_Dataset(data.Dataset):
19
    def __init__(self, nctid_lst, label_lst, smiles_lst, icdcode_lst, criteria_lst):
20
        self.nctid_lst = nctid_lst 
21
        self.label_lst = label_lst 
22
        self.smiles_lst = smiles_lst 
23
        self.icdcode_lst = icdcode_lst 
24
        self.criteria_lst = criteria_lst 
25
26
    def __len__(self):
27
        return len(self.nctid_lst)
28
29
    def __getitem__(self, index):
30
        return self.nctid_lst[index], self.label_lst[index], self.smiles_lst[index], self.icdcode_lst[index], self.criteria_lst[index]
31
    #### smiles_lst[index] is list of smiles
32
33
class Trial_Dataset_Complete(Trial_Dataset):
34
    def __init__(self, nctid_lst, status_lst, why_stop_lst, label_lst, phase_lst, 
35
                       diseases_lst, icdcode_lst, drugs_lst, smiles_lst, criteria_lst):
36
        Trial_Dataset.__init__(self, nctid_lst, label_lst, smiles_lst, icdcode_lst, criteria_lst)
37
        self.status_lst = status_lst 
38
        self.why_stop_lst = why_stop_lst 
39
        self.phase_lst = phase_lst 
40
        self.diseases_lst = diseases_lst 
41
        self.drugs_lst = drugs_lst 
42
43
    def __getitem__(self, index):
44
        return self.nctid_lst[index], self.status_lst[index], self.why_stop_lst[index], self.label_lst[index], self.phase_lst[index], \
45
               self.diseases_lst[index], self.icdcode_lst[index], self.drugs_lst[index], self.smiles_lst[index], self.criteria_lst[index]
46
47
48
49
class ADMET_Dataset(data.Dataset):
50
    def __init__(self, smiles_lst, label_lst):
51
        self.smiles_lst = smiles_lst 
52
        self.label_lst = label_lst 
53
    
54
    def __len__(self):
55
        return len(self.smiles_lst)
56
57
    def __getitem__(self, index):
58
        return self.smiles_lst[index], self.label_lst[index]
59
60
def admet_collate_fn(x):
61
    smiles_lst = [i[0] for i in x]
62
    label_vec = default_collate([int(i[1]) for i in x])  ### shape n, 
63
    return [smiles_lst, label_vec]
64
65
66
def smiles_txt_to_lst(text):
67
    """
68
        "['CN[C@H]1CC[C@@H](C2=CC(Cl)=C(Cl)C=C2)C2=CC=CC=C12', 'CNCCC=C1C2=CC=CC=C2CCC2=CC=CC=C12']" 
69
    """
70
    text = text[1:-1]
71
    lst = [i.strip()[1:-1] for i in text.split(',')]
72
    return lst 
73
74
def icdcode_text_2_lst_of_lst(text):
75
    text = text[2:-2]
76
    lst_lst = []
77
    for i in text.split('", "'):
78
        i = i[1:-1]
79
        lst_lst.append([j.strip()[1:-1] for j in i.split(',')])
80
    return lst_lst 
81
82
def trial_collate_fn(x):
83
    nctid_lst = [i[0] for i in x]     ### ['NCT00604461', ..., 'NCT00788957'] 
84
    label_vec = default_collate([int(i[1]) for i in x])  ### shape n, 
85
    smiles_lst = [smiles_txt_to_lst(i[2]) for i in x]
86
    icdcode_lst = [icdcode_text_2_lst_of_lst(i[3]) for i in x]
87
    criteria_lst = [protocol2feature(i[4], sentence2vec) for i in x]
88
    return [nctid_lst, label_vec, smiles_lst, icdcode_lst, criteria_lst]
89
90
def trial_complete_collate_fn(x):
91
    nctid_lst = [i[0] for i in x]     ### ['NCT00604461', ..., 'NCT00788957'] 
92
    status_lst = [i[1] for i in x]
93
    why_stop_lst = [i[2] for i in x]
94
    label_vec = default_collate([int(i[3]) for i in x])  ### shape n, 
95
    phase_lst = [i[4] for i in x]
96
    diseases_lst = [i[5] for i in x]
97
    icdcode_lst = [icdcode_text_2_lst_of_lst(i[6]) for i in x]
98
    drugs_lst = [i[7] for i in x]
99
    smiles_lst = [smiles_txt_to_lst(i[8]) for i in x]
100
    criteria_lst = [protocol2feature(i[9], sentence2vec) for i in x]
101
    return [nctid_lst, status_lst, why_stop_lst, label_vec, phase_lst, diseases_lst, icdcode_lst, drugs_lst, smiles_lst, criteria_lst]
102
103
def csv_three_feature_2_dataloader(csvfile, shuffle, batch_size):
104
    with open(csvfile, 'r') as csvfile:
105
        rows = list(csv.reader(csvfile, delimiter=','))[1:]
106
    ## nctid,status,why_stop,label,phase,diseases,icdcodes,drugs,smiless,criteria
107
    nctid_lst = [row[0] for row in rows]
108
    label_lst = [row[3] for row in rows]
109
    icdcode_lst = [row[6] for row in rows]
110
    drugs_lst = [row[7] for row in rows]
111
    smiles_lst = [row[8] for row in rows]
112
    criteria_lst = [row[9] for row in rows] 
113
    dataset = Trial_Dataset(nctid_lst, label_lst, smiles_lst, icdcode_lst, criteria_lst)
114
    data_loader = data.DataLoader(dataset, batch_size = batch_size, shuffle = shuffle, collate_fn = trial_collate_fn)
115
    return data_loader
116
117
def csv_three_feature_2_complete_dataloader(csvfile, shuffle, batch_size):
118
    with open(csvfile, 'r') as csvfile:
119
        rows = list(csv.reader(csvfile, delimiter=','))[1:] 
120
    nctid_lst = [row[0] for row in rows]
121
    status_lst = [row[1] for row in rows]
122
    why_stop_lst = [row[2] for row in rows]
123
    label_lst = [row[3] for row in rows]
124
    phase_lst = [row[4] for row in rows]
125
    diseases_lst = [row[5] for row in rows]
126
    icdcode_lst = [row[6] for row in rows]
127
    drugs_lst = [row[7] for row in rows]
128
    smiles_lst = [row[8] for row in rows]
129
    new_drugs_lst, new_smiles_lst = [], []
130
    criteria_lst = [row[9] for row in rows] 
131
    dataset = Trial_Dataset_Complete(nctid_lst, status_lst, why_stop_lst, label_lst, phase_lst, 
132
                                     diseases_lst, icdcode_lst, drugs_lst, smiles_lst, criteria_lst)
133
    data_loader = data.DataLoader(dataset, batch_size = batch_size, shuffle = shuffle, collate_fn = trial_complete_collate_fn)
134
    return data_loader 
135
136
137
138
139
140
141
def smiles_txt_to_2lst(smiles_txt_file):
142
    with open(smiles_txt_file, 'r') as fin:
143
        lines = fin.readlines() 
144
    smiles_lst = [line.split()[0] for line in lines]
145
    label_lst = [int(line.split()[1]) for line in lines]
146
    return smiles_lst, label_lst 
147
148
def generate_admet_dataloader_lst(batch_size):
149
    datafolder = "data/ADMET/cooked/"
150
    name_lst = ["absorption", 'distribution', 'metabolism', 'excretion', 'toxicity']
151
    dataloader_lst = []
152
    for i,name in enumerate(name_lst):
153
        train_file = os.path.join(datafolder, name + '_train.txt')
154
        test_file = os.path.join(datafolder, name +'_valid.txt')
155
        train_smiles_lst, train_label_lst = smiles_txt_to_2lst(train_file)
156
        test_smiles_lst, test_label_lst = smiles_txt_to_2lst(test_file)
157
        train_dataset = ADMET_Dataset(smiles_lst = train_smiles_lst, label_lst = train_label_lst)
158
        test_dataset = ADMET_Dataset(smiles_lst = test_smiles_lst, label_lst = test_label_lst)
159
        train_dataloader = data.DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
160
        test_dataloader = data.DataLoader(test_dataset, batch_size = batch_size, shuffle = False)
161
        dataloader_lst.append((train_dataloader, test_dataloader))
162
    return dataloader_lst 
163
164
# ## x is a list, len(x)=batch_size, x[i] is tuple, len(x[0])=5  
165
# def mpnn_feature_collate_func(x): 
166
#   return [torch.cat([x[j][i] for j in range(len(x))], 0) for i in range(len(x[0]))]
167
168
169
# def mpnn_collate_func(x):
170
#   #print("len(x) is ", len(x)) ## batch_size 
171
#   #print("len(x[0]) is ", len(x[0])) ## 3--- data_process_loader.__getitem__ 
172
#   mpnn_feature = [i[0] for i in x]
173
#   #print("len(mpnn_feature)", len(mpnn_feature), "len(mpnn_feature[0])", len(mpnn_feature[0]))
174
#   mpnn_feature = mpnn_feature_collate_func(mpnn_feature)
175
#   from torch.utils.data.dataloader import default_collate
176
#   x_remain = [i[1:] for i in x]
177
#   x_remain_collated = default_collate(x_remain)
178
#   return [mpnn_feature] + x_remain_collated
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200