'''
Molecule encoders: SMILES batch -> embedding batch.

input:
    smiles batch

utility
    1. graph MPN
    2. smiles
    3. morgan feature

output:
    1. embedding batch

deeppurpose
    DDI
    encoders  model

to do
    lst -> dataloader -> feature -> model

    mpnn's feature -> collate -> model
'''

import csv
from itertools import chain
from tqdm import tqdm
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt

import rdkit
import rdkit.Chem as Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.info')
RDLogger.DisableLog('rdApp.*')
# from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
import torch
torch.manual_seed(0)
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils import data  #### data.Dataset
from torch.utils.data.dataloader import default_collate
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from HINT.module import Highway


def get_drugbank_smiles_lst():
    """Return column 27 of every data row in ``data/drugbank_drugs_info.csv``.

    The first row of the file is treated as a header and skipped.
    """
    drugfile = 'data/drugbank_drugs_info.csv'
    with open(drugfile, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=',')
        next(reader, None)  # skip the header row (no error on an empty file)
        return [row[27] for row in reader]


def txt_to_lst(text):
    """Parse the textual form of a list of quoted strings into a list.

    Example input:
    "['CN[C@H]1CC[C@@H](C2=CC(Cl)=C(Cl)C=C2)C2=CC=CC=C12', 'CNCCC=C1C2=CC=CC=C2CCC2=CC=CC=C12']"

    The outer brackets and the quote characters around each item are
    stripped. Empty items are dropped, so ``"[]"`` yields ``[]`` rather than
    a list holding one empty string.
    """
    inner = text.strip()[1:-1]
    items = (piece.strip()[1:-1] for piece in inner.split(','))
    return [item for item in items if item]


def get_cooked_data_smiles_lst():
    """Return the unique SMILES found in column 8 of ``data/raw_data.csv``.

    Each cell of that column holds a textual list (see ``txt_to_lst``); all
    lists are flattened and de-duplicated. The result is sorted so the order
    is reproducible from run to run.
    """
    cooked_file = 'data/raw_data.csv'
    with open(cooked_file, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=',')
        next(reader, None)  # skip the header row
        smiles_fields = [row[8] for row in reader]
    # chain.from_iterable flattens in linear time and copes with zero rows.
    unique_smiles = set(chain.from_iterable(map(txt_to_lst, smiles_fields)))
    return sorted(unique_smiles)


def create_var(tensor, requires_grad=None):
    """Wrap ``tensor`` in ``Variable``, forwarding ``requires_grad`` if given."""
    if requires_grad is None:
        return Variable(tensor)
    else:
        return Variable(tensor, requires_grad=requires_grad)


def index_select_ND(source, dim, index):
    """Gather rows of ``source`` for every entry of ``index``.

    The result has shape ``index.size() + source.size()[1:]``; the trailing
    shape is taken from ``source`` dimensions 1 onward, so the reshape is only
    consistent when ``dim`` is 0 (the only way this file calls it).
    """
    index_size = index.size()
    suffix_dim = source.size()[1:]
    final_size = index_size + suffix_dim
    target = source.index_select(dim, index.view(-1))
    return target.view(final_size)


def get_mol(smiles):
    """Parse ``smiles`` with RDKit and kekulize it.

    Returns ``None`` when RDKit cannot parse the string. Errors raised by
    ``Chem.Kekulize`` are not caught here.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    Chem.Kekulize(mol)
    return mol


ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
# element one-hot + degree (6) + formal charge (5) + chiral tag (4) + aromatic flag (1)
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
# bond type / ring flags (5) + stereo one-hot (6)
BOND_FDIM = 5 + 6
# number of neighbour slots per atom / per directed bond in agraph and bgraph
MAX_NB = 6
### basic setting from https://github.com/wengong-jin/iclr19-graph2graph/blob/master/fast_jtnn/mpn.py


def onek_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` over ``allowable_set``.

    A value not in the set is mapped to the last entry of the set.
    """
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]


def atom_features(atom):
    """Return the ``ATOM_FDIM``-long feature tensor of an RDKit atom."""
    return torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
            + onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
            + onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
            + onek_encoding_unk(int(atom.GetChiralTag()), [0,1,2,3])
            + [atom.GetIsAromatic()])


def bond_features(bond):
    """Return the ``BOND_FDIM``-long feature tensor of an RDKit bond."""
    bt = bond.GetBondType()
    stereo = int(bond.GetStereo())
    fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
    fstereo = onek_encoding_unk(stereo, [0,1,2,3,4,5])
    return torch.Tensor(fbond + fstereo)


def smiles2mpnnfeature(smiles):
    ## from mpn.py::tensorize
    '''
    Turn one SMILES string into the tensors consumed by ``MPNN``.

    data-flow:
        data_process(): apply(smiles2mpnnfeature)
        DBTA: train(): data.DataLoader(data_process_loader())
        mpnn_collate_func()

    Returns a list ``[fatoms, fbonds, agraph, bgraph, shape_tensor]``:
        fatoms: (n_atoms, ATOM_FDIM) atom features
        fbonds: (total_bonds, ATOM_FDIM + BOND_FDIM); row 0 is an all-zero
            padding entry and every chemical bond contributes two directed rows
        agraph: (n_atoms, MAX_NB) indices of the directed bonds entering each atom
        bgraph: (total_bonds, MAX_NB) for directed bond (x, y), indices of the
            bonds entering x that do not originate from y
        shape_tensor: (1, 2) holding [number of fatoms rows, number of fbonds rows]

    A SMILES that cannot be parsed, or that parses to a molecule with no
    atoms, yields tensors with zero rows.
    '''
    mol = get_mol(smiles)
    # A molecule with zero atoms would make torch.stack fail on an empty
    # list below, so it is handled like an unparsable SMILES.
    if mol is None or mol.GetNumAtoms() == 0:
        fatoms = torch.zeros(0, ATOM_FDIM)
        fbonds = torch.zeros(0, ATOM_FDIM + BOND_FDIM)
        agraph = torch.zeros(0, MAX_NB)
        bgraph = torch.zeros(0, MAX_NB)
    else:
        padding = torch.zeros(ATOM_FDIM + BOND_FDIM)
        fatoms, fbonds = [], [padding]
        in_bonds, all_bonds = [], [(-1, -1)]  # index 0 is reserved for padding
        n_atoms = mol.GetNumAtoms()
        for atom in mol.GetAtoms():
            fatoms.append(atom_features(atom))
            in_bonds.append([])

        for bond in mol.GetBonds():
            x = bond.GetBeginAtom().GetIdx()
            y = bond.GetEndAtom().GetIdx()
            bond_feat = bond_features(bond)

            # directed bond x -> y
            b = len(all_bonds)
            all_bonds.append((x, y))
            fbonds.append(torch.cat([fatoms[x], bond_feat], 0))
            in_bonds[y].append(b)

            # directed bond y -> x
            b = len(all_bonds)
            all_bonds.append((y, x))
            fbonds.append(torch.cat([fatoms[y], bond_feat], 0))
            in_bonds[x].append(b)

        total_bonds = len(all_bonds)
        fatoms = torch.stack(fatoms, 0)
        fbonds = torch.stack(fbonds, 0)
        agraph = torch.zeros(n_atoms, MAX_NB).long()
        bgraph = torch.zeros(total_bonds, MAX_NB).long()
        # NOTE(review): an atom with more than MAX_NB neighbours raises
        # IndexError here, as in the original code.
        for a in range(n_atoms):
            for i, b in enumerate(in_bonds[a]):
                agraph[a, i] = b

        for b1 in range(1, total_bonds):
            x, y = all_bonds[b1]
            for i, b2 in enumerate(in_bonds[x]):
                # skip the reverse of b1; its slot stays 0 (the padding bond)
                if all_bonds[b2][0] != y:
                    bgraph[b1, i] = b2

    Natom, Nbond = fatoms.shape[0], fbonds.shape[0]
    shape_tensor = torch.Tensor([Natom, Nbond]).view(1, -1)
    return [fatoms.float(), fbonds.float(), agraph.float(), bgraph.float(), shape_tensor]


class smiles_dataset(data.Dataset):
    """Dataset pairing each SMILES string with a label.

    Items are ``(smiles2mpnnfeature(smiles), label)`` and are featurised on
    access.
    """

    def __init__(self, smiles_lst, label_lst):
        self.smiles_lst = smiles_lst
        self.label_lst = label_lst

    def __len__(self):
        return len(self.smiles_lst)

    def __getitem__(self, index):
        smiles = self.smiles_lst[index]
        label = self.label_lst[index]
        smiles_feature = smiles2mpnnfeature(smiles)
        return smiles_feature, label


## DTI.py --> collate

## x is a list, len(x)=batch_size, x[i] is tuple, len(x[0])=5
def mpnn_feature_collate_func(x):
    """Concatenate each of the per-molecule feature tensors along dim 0."""
    return [torch.cat(parts, 0) for parts in zip(*x)]


def mpnn_collate_func(x):
    """Collate a batch whose first field is an MPNN feature list.

    The feature lists are merged with ``mpnn_feature_collate_func``; the
    remaining fields of every sample go through torch's ``default_collate``.
    """
    mpnn_feature = mpnn_feature_collate_func([sample[0] for sample in x])
    x_remain = [sample[1:] for sample in x]
    x_remain_collated = default_collate(x_remain)
    return [mpnn_feature] + x_remain_collated


def data_loader():
    """Build a DataLoader over the cooked-data SMILES with dummy labels of 1."""
    smiles_lst = get_cooked_data_smiles_lst()
    label_lst = [1 for _ in smiles_lst]
    dataset = smiles_dataset(smiles_lst, label_lst)
    dataloader = data.DataLoader(dataset, batch_size=32, collate_fn=mpnn_collate_func)
    return dataloader


class MPNN(nn.Sequential):
    """Message-passing network that embeds molecules into ``mpnn_hidden_size`` dims."""

    def __init__(self, mpnn_hidden_size, mpnn_depth, device):
        super(MPNN, self).__init__()
        self.mpnn_hidden_size = mpnn_hidden_size
        self.mpnn_depth = mpnn_depth

        self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, self.mpnn_hidden_size, bias=False)
        self.W_h = nn.Linear(self.mpnn_hidden_size, self.mpnn_hidden_size, bias=False)
        self.W_o = nn.Linear(ATOM_FDIM + self.mpnn_hidden_size, self.mpnn_hidden_size)

        self.device = device
        self.to(self.device)  # Module.to moves parameters in place

    def set_device(self, device):
        """Record ``device`` for later tensor placement (parameters are not moved)."""
        self.device = device

    @property
    def embedding_size(self):
        return self.mpnn_hidden_size

    def forward(self, feature):
        """Embed a collated feature batch; see ``feature_forward``.

        Defined explicitly because the inherited ``nn.Sequential.forward``
        would chain W_i, W_h and W_o, which is not the intended computation.
        """
        return self.feature_forward(feature)

    ### forward single molecule sequentially.
    def feature_forward(self, feature):
        '''
        Embed every molecule of a collated batch, one molecule at a time.

        feature: [fatoms, fbonds, agraph, bgraph, atoms_bonds] as produced by
            smiles2mpnnfeature / mpnn_feature_collate_func. Row i of
            atoms_bonds holds the atom and bond counts of molecule i, which
            are used to slice the concatenated tensors.

        Returns a (batch_size, mpnn_hidden_size) tensor, or None for an empty
        batch. A molecule with zero atoms gets an all-zero embedding.
        '''
        fatoms, fbonds, agraph, bgraph, atoms_bonds = feature
        agraph = agraph.long()
        bgraph = bgraph.long()
        atoms_bonds = atoms_bonds.long()
        batch_size = atoms_bonds.shape[0]
        N_atoms, N_bonds = 0, 0  # running offsets into the concatenated tensors
        embeddings = []
        for i in range(batch_size):
            n_a = atoms_bonds[i, 0].item()
            n_b = atoms_bonds[i, 1].item()
            if n_a == 0:
                embed = create_var(torch.zeros(1, self.mpnn_hidden_size))
                embeddings.append(embed.to(self.device))
                continue
            sub_fatoms = fatoms[N_atoms:N_atoms + n_a, :].to(self.device)
            sub_fbonds = fbonds[N_bonds:N_bonds + n_b, :].to(self.device)
            sub_agraph = agraph[N_atoms:N_atoms + n_a, :].to(self.device)
            sub_bgraph = bgraph[N_bonds:N_bonds + n_b, :].to(self.device)
            embed = self.single_feature_forward(sub_fatoms, sub_fbonds, sub_agraph, sub_bgraph)
            embed = embed.to(self.device)
            embeddings.append(embed)
            N_atoms += n_a
            N_bonds += n_b
        if len(embeddings) == 0:
            return None
        else:
            return torch.cat(embeddings, 0)

    def single_feature_forward(self, fatoms, fbonds, agraph, bgraph):
        '''
        Embed a single molecule.

        fatoms: (x, 39)
        fbonds: (y, 50)
        agraph: (x, 6)
        bgraph: (y, 6)

        Returns a (1, mpnn_hidden_size) tensor: the mean of the per-atom
        hidden states, or zeros when the molecule has no atoms.
        '''
        ### invalid molecule
        if fatoms.shape[0] == 0:
            return create_var(torch.zeros(1, self.mpnn_hidden_size).to(self.device))
        agraph = agraph.long()
        bgraph = bgraph.long()
        fatoms = create_var(fatoms).to(self.device)
        fbonds = create_var(fbonds).to(self.device)
        agraph = create_var(agraph).to(self.device)
        bgraph = create_var(bgraph).to(self.device)

        binput = self.W_i(fbonds)
        message = F.relu(binput)
        for _ in range(self.mpnn_depth - 1):
            nei_message = index_select_ND(message, 0, bgraph)
            nei_message = nei_message.sum(dim=1)
            nei_message = self.W_h(nei_message)
            message = F.relu(binput + nei_message)

        nei_message = index_select_ND(message, 0, agraph)
        nei_message = nei_message.sum(dim=1)
        ainput = torch.cat([fatoms, nei_message], dim=1)
        atom_hiddens = F.relu(self.W_o(ainput))
        return torch.mean(atom_hiddens, 0).view(1, -1)

    def forward_single_smiles(self, smiles):
        """Featurise and embed one SMILES string; returns shape (1, hidden)."""
        fatoms, fbonds, agraph, bgraph, _ = smiles2mpnnfeature(smiles)
        embed = self.single_feature_forward(fatoms, fbonds, agraph, bgraph).view(1, -1)
        return embed

    def forward_smiles_lst(self, smiles_lst):
        """Embed each SMILES of the list; returns shape (len(smiles_lst), hidden)."""
        embed_lst = [self.forward_single_smiles(smiles) for smiles in smiles_lst]
        embed_all = torch.cat(embed_lst, 0)
        return embed_all

    def forward_smiles_lst_average(self, smiles_lst):
        """Embed the list and average over molecules; returns shape (1, hidden)."""
        embed_all = self.forward_smiles_lst(smiles_lst)
        embed_avg = torch.mean(embed_all, 0).view(1, -1)
        return embed_avg

    def forward_smiles_lst_lst(self, smiles_lst_lst):
        """Return one averaged embedding per inner list; shape (n, hidden)."""
        embed_lst = [self.forward_smiles_lst_average(smiles_lst) for smiles_lst in smiles_lst_lst]
        embed_all = torch.cat(embed_lst, 0)  #### n,dim
        return embed_all


class ADMET(nn.Sequential):
    """Per-task Highway + linear heads on top of a shared molecule encoder."""

    # number of task heads that are created
    NUM_TASKS = 5
    # number of leading tasks actually used by ``test`` and ``train``
    NUM_ACTIVE_TASKS = 1

    def __init__(self, molecule_encoder, highway_num, device,
                    epoch, lr, weight_decay, save_name):
        super(ADMET, self).__init__()
        self.molecule_encoder = molecule_encoder
        self.embedding_size = self.molecule_encoder.embedding_size
        self.highway_num = highway_num
        self.highway_nn_lst = nn.ModuleList([Highway(size=self.embedding_size, num_layers=self.highway_num) for _ in range(self.NUM_TASKS)])
        self.fc_output_lst = nn.ModuleList([nn.Linear(self.embedding_size, 1) for _ in range(self.NUM_TASKS)])
        self.f = F.relu
        self.loss = nn.BCEWithLogitsLoss()

        self.epoch = epoch
        self.lr = lr
        self.weight_decay = weight_decay
        self.save_name = save_name

        self.device = device
        self.to(device)  # Module.to moves parameters in place

    def set_device(self, device):
        """Record ``device`` here and on the encoder (parameters are not moved)."""
        self.device = device
        self.molecule_encoder.set_device(device)

    def forward_smiles_lst_embedding(self, smiles_lst, idx):
        """Encode ``smiles_lst`` and pass it through the Highway block of task ``idx``."""
        embed_all = self.molecule_encoder.forward_smiles_lst(smiles_lst)
        output = self.highway_nn_lst[idx](embed_all)
        return output

    def forward_embedding_to_pred(self, embeded, idx):
        """Map embeddings to one logit per row with the linear head of task ``idx``."""
        return self.fc_output_lst[idx](embeded)

    def forward_smiles_lst_pred(self, smiles_lst, idx):
        """Return the task-``idx`` logits for a list of SMILES."""
        embeded = self.forward_smiles_lst_embedding(smiles_lst, idx)
        fc_output = self.forward_embedding_to_pred(embeded, idx)
        return fc_output

    def test(self, dataloader_lst, return_loss=True):
        """Return the mean BCE-with-logits loss over the active tasks.

        ``return_loss`` is accepted for interface compatibility and unused.
        """
        loss_lst = []
        with torch.no_grad():  # evaluation needs no autograd graph
            for idx in range(self.NUM_ACTIVE_TASKS):
                single_loss_lst = []
                for smiles_lst, label_vec in dataloader_lst[idx]:
                    output = self.forward_smiles_lst_pred(smiles_lst, idx).view(-1)
                    loss = self.loss(output, label_vec.to(self.device).float())
                    single_loss_lst.append(loss.item())
                loss_lst.append(np.mean(single_loss_lst))
        return np.mean(loss_lst)

    def train(self, train_loader_lst=True, valid_loader_lst=None):
        """Fit the model, then restore the weights with the best validation loss.

        train_loader_lst / valid_loader_lst: sequences of loaders indexed by
            task; only the first ``NUM_ACTIVE_TASKS`` entries are used. An
            epoch ends as soon as one active training loader is exhausted.

        This method shadows ``nn.Module.train(mode)``. When called with a
        single boolean (as ``eval()`` and parent modules do) the call is
        forwarded to the base implementation so mode switching keeps working.
        """
        if isinstance(train_loader_lst, bool) and valid_loader_lst is None:
            return super(ADMET, self).train(train_loader_lst)

        opt = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        best_valid_loss = self.test(valid_loader_lst, return_loss=True)
        # Snapshot the weights rather than the module: rebinding ``self`` to a
        # copied module would not change the caller's object.
        best_state = deepcopy(self.state_dict())
        for _ in tqdm(range(self.epoch)):
            data_iterator_lst = [iter(train_loader_lst[idx]) for idx in range(self.NUM_ACTIVE_TASKS)]
            exhausted = False
            while not exhausted:
                for idx, data_iterator in enumerate(data_iterator_lst):
                    try:
                        smiles_lst, label_vec = next(data_iterator)
                    except StopIteration:
                        # only loader exhaustion ends the epoch; real errors propagate
                        exhausted = True
                        break
                    output = self.forward_smiles_lst_pred(smiles_lst, idx).view(-1)
                    loss = self.loss(output, label_vec.to(self.device).float())
                    opt.zero_grad()
                    loss.backward()
                    opt.step()
            valid_loss = self.test(valid_loader_lst, return_loss=True)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_state = deepcopy(self.state_dict())

        self.load_state_dict(best_state)


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MPNN(mpnn_hidden_size=50, mpnn_depth=3, device=device)
    dataloader = data_loader()
    for smiles_feature, labels in dataloader:
        embedding = model(smiles_feature)
        print(embedding.shape)