--- a +++ b/HINT/dataloader.py @@ -0,0 +1,195 @@ +''' + +(I). Trial_Dataset for prediction +(II). Trial_Dataset_Complete for interpretation +(III). SMILES lst +(IV). disease lst icd-code + +''' + +import torch, csv, os +from torch.utils import data +from torch.utils.data.dataloader import default_collate +from HINT.molecule_encode import smiles2mpnnfeature +from HINT.protocol_encode import protocol2feature, load_sentence_2_vec + +sentence2vec = load_sentence_2_vec() + + +class Trial_Dataset(data.Dataset): + def __init__(self, nctid_lst, label_lst, smiles_lst, icdcode_lst, criteria_lst): + self.nctid_lst = nctid_lst + self.label_lst = label_lst + self.smiles_lst = smiles_lst + self.icdcode_lst = icdcode_lst + self.criteria_lst = criteria_lst + + def __len__(self): + return len(self.nctid_lst) + + def __getitem__(self, index): + return self.nctid_lst[index], self.label_lst[index], self.smiles_lst[index], self.icdcode_lst[index], self.criteria_lst[index] + #### smiles_lst[index] is list of smiles + + +class Trial_Dataset_Complete(Trial_Dataset): + def __init__(self, nctid_lst, status_lst, why_stop_lst, label_lst, phase_lst, + diseases_lst, icdcode_lst, drugs_lst, smiles_lst, criteria_lst): + Trial_Dataset.__init__(self, nctid_lst, label_lst, smiles_lst, icdcode_lst, criteria_lst) + self.status_lst = status_lst + self.why_stop_lst = why_stop_lst + self.phase_lst = phase_lst + self.diseases_lst = diseases_lst + self.drugs_lst = drugs_lst + + def __getitem__(self, index): + return self.nctid_lst[index], self.status_lst[index], self.why_stop_lst[index], self.label_lst[index], self.phase_lst[index], \ + self.diseases_lst[index], self.icdcode_lst[index], self.drugs_lst[index], self.smiles_lst[index], self.criteria_lst[index] + + +class ADMET_Dataset(data.Dataset): + def __init__(self, smiles_lst, label_lst): + self.smiles_lst = smiles_lst + self.label_lst = label_lst + + def __len__(self): + return len(self.smiles_lst) + + def __getitem__(self, index): + return self.smiles_lst[index], self.label_lst[index] + +def admet_collate_fn(x): + smiles_lst = [i[0] for i in x] + label_vec = default_collate([int(i[1]) for i in x]) ### shape n, + return [smiles_lst, label_vec] + +def smiles_txt_to_lst(text): + """ + "['CN[C@H]1CC[C@@H](C2=CC(Cl)=C(Cl)C=C2)C2=CC=CC=C12', 'CNCCC=C1C2=CC=CC=C2CCC2=CC=CC=C12']" + """ + text = text[1:-1] + lst = [i.strip()[1:-1] for i in text.split(',')] + return lst + +def icdcode_text_2_lst_of_lst(text): + text = text[2:-2] + lst_lst = [] + for i in text.split('", "'): + i = i[1:-1] + lst_lst.append([j.strip()[1:-1] for j in i.split(',')]) + return lst_lst + +def trial_collate_fn(x): + nctid_lst = [i[0] for i in x] ### ['NCT00604461', ..., 'NCT00788957'] + label_vec = default_collate([int(i[1]) for i in x]) ### shape n, + smiles_lst = [smiles_txt_to_lst(i[2]) for i in x] + icdcode_lst = [icdcode_text_2_lst_of_lst(i[3]) for i in x] + criteria_lst = [protocol2feature(i[4], sentence2vec) for i in x] + return [nctid_lst, label_vec, smiles_lst, icdcode_lst, criteria_lst] + +def trial_complete_collate_fn(x): + nctid_lst = [i[0] for i in x] ### ['NCT00604461', ..., 'NCT00788957'] + status_lst = [i[1] for i in x] + why_stop_lst = [i[2] for i in x] + label_vec = default_collate([int(i[3]) for i in x]) ### shape n, + phase_lst = [i[4] for i in x] + diseases_lst = [i[5] for i in x] + icdcode_lst = [icdcode_text_2_lst_of_lst(i[6]) for i in x] + drugs_lst = [i[7] for i in x] + smiles_lst = [smiles_txt_to_lst(i[8]) for i in x] + criteria_lst = [protocol2feature(i[9], sentence2vec) for i in x] + return [nctid_lst, status_lst, why_stop_lst, label_vec, phase_lst, diseases_lst, icdcode_lst, drugs_lst, smiles_lst, criteria_lst] + +def csv_three_feature_2_dataloader(csvfile, shuffle, batch_size): + with open(csvfile, 'r') as csvfile: + rows = list(csv.reader(csvfile, delimiter=','))[1:] + ## nctid,status,why_stop,label,phase,diseases,icdcodes,drugs,smiless,criteria + nctid_lst = [row[0] for row in rows] + label_lst = [row[3] for row in rows] + icdcode_lst = [row[6] for row in rows] + drugs_lst = [row[7] for row in rows] + smiles_lst = [row[8] for row in rows] + criteria_lst = [row[9] for row in rows] + dataset = Trial_Dataset(nctid_lst, label_lst, smiles_lst, icdcode_lst, criteria_lst) + data_loader = data.DataLoader(dataset, batch_size = batch_size, shuffle = shuffle, collate_fn = trial_collate_fn) + return data_loader + +def csv_three_feature_2_complete_dataloader(csvfile, shuffle, batch_size): + with open(csvfile, 'r') as csvfile: + rows = list(csv.reader(csvfile, delimiter=','))[1:] + nctid_lst = [row[0] for row in rows] + status_lst = [row[1] for row in rows] + why_stop_lst = [row[2] for row in rows] + label_lst = [row[3] for row in rows] + phase_lst = [row[4] for row in rows] + diseases_lst = [row[5] for row in rows] + icdcode_lst = [row[6] for row in rows] + drugs_lst = [row[7] for row in rows] + smiles_lst = [row[8] for row in rows] + new_drugs_lst, new_smiles_lst = [], [] + criteria_lst = [row[9] for row in rows] + dataset = Trial_Dataset_Complete(nctid_lst, status_lst, why_stop_lst, label_lst, phase_lst, + diseases_lst, icdcode_lst, drugs_lst, smiles_lst, criteria_lst) + data_loader = data.DataLoader(dataset, batch_size = batch_size, shuffle = shuffle, collate_fn = trial_complete_collate_fn) + return data_loader + +def smiles_txt_to_2lst(smiles_txt_file): + with open(smiles_txt_file, 'r') as fin: + lines = fin.readlines() + smiles_lst = [line.split()[0] for line in lines] + label_lst = [int(line.split()[1]) for line in lines] + return smiles_lst, label_lst + +def generate_admet_dataloader_lst(batch_size): + datafolder = "data/ADMET/cooked/" + name_lst = ["absorption", 'distribution', 'metabolism', 'excretion', 'toxicity'] + dataloader_lst = [] + for i,name in enumerate(name_lst): + train_file = os.path.join(datafolder, name + '_train.txt') + test_file = os.path.join(datafolder, name +'_valid.txt') + train_smiles_lst, train_label_lst = smiles_txt_to_2lst(train_file) + test_smiles_lst, test_label_lst = smiles_txt_to_2lst(test_file) + train_dataset = ADMET_Dataset(smiles_lst = train_smiles_lst, label_lst = train_label_lst) + test_dataset = ADMET_Dataset(smiles_lst = test_smiles_lst, label_lst = test_label_lst) + train_dataloader = data.DataLoader(train_dataset, batch_size = batch_size, shuffle=True) + test_dataloader = data.DataLoader(test_dataset, batch_size = batch_size, shuffle = False) + dataloader_lst.append((train_dataloader, test_dataloader)) + return dataloader_lst + +# ## x is a list, len(x)=batch_size, x[i] is tuple, len(x[0])=5 +# def mpnn_feature_collate_func(x): +# return [torch.cat([x[j][i] for j in range(len(x))], 0) for i in range(len(x[0]))] + + +# def mpnn_collate_func(x): +# #print("len(x) is ", len(x)) ## batch_size +# #print("len(x[0]) is ", len(x[0])) ## 3--- data_process_loader.__getitem__ +# mpnn_feature = [i[0] for i in x] +# #print("len(mpnn_feature)", len(mpnn_feature), "len(mpnn_feature[0])", len(mpnn_feature[0])) +# mpnn_feature = mpnn_feature_collate_func(mpnn_feature) +# from torch.utils.data.dataloader import default_collate +# x_remain = [i[1:] for i in x] +# x_remain_collated = default_collate(x_remain) +# return [mpnn_feature] + x_remain_collated + + + + + + + + + + + + + + + + + + + + + +