--- a +++ b/create_data.py @@ -0,0 +1,139 @@ +import pandas as pd +import numpy as np +import os +import json,pickle +from collections import OrderedDict +from rdkit import Chem +from rdkit.Chem import MolFromSmiles +import networkx as nx +from utils import * + +def atom_features(atom): + return np.array(one_of_k_encoding_unk(atom.GetSymbol(),['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na','Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb','Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H','Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr','Cr', 'Pt', 'Hg', 'Pb', 'Unknown']) + + one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6,7,8,9,10]) + + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4, 5, 6,7,8,9,10]) + + one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6,7,8,9,10]) + + [atom.GetIsAromatic()]) + +def one_of_k_encoding(x, allowable_set): + if x not in allowable_set: + raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set)) + return list(map(lambda s: x == s, allowable_set)) + +def one_of_k_encoding_unk(x, allowable_set): + """Maps inputs not in the allowable set to the last element.""" + if x not in allowable_set: + x = allowable_set[-1] + return list(map(lambda s: x == s, allowable_set)) + +def smile_to_graph(smile): + mol = Chem.MolFromSmiles(smile) + + c_size = mol.GetNumAtoms() + + features = [] + for atom in mol.GetAtoms(): + feature = atom_features(atom) + features.append( feature / sum(feature) ) + + edges = [] + for bond in mol.GetBonds(): + edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) + g = nx.Graph(edges).to_directed() + edge_index = [] + for e1, e2 in g.edges: + edge_index.append([e1, e2]) + + return c_size, features, edge_index + +def seq_cat(prot): + x = np.zeros(max_seq_len) + for i, ch in enumerate(prot[:max_seq_len]): + x[i] = seq_dict[ch] + return x + + +# from DeepDTA data +all_prots = [] +datasets = ['kiba','davis'] +for dataset in datasets: + print('convert data from DeepDTA for ', dataset) + fpath = 'data/' + dataset + '/' + train_fold = json.load(open(fpath + "folds/train_fold_setting1.txt")) + train_fold = [ee for e in train_fold for ee in e ] + valid_fold = json.load(open(fpath + "folds/test_fold_setting1.txt")) + ligands = json.load(open(fpath + "ligands_can.txt"), object_pairs_hook=OrderedDict) + proteins = json.load(open(fpath + "proteins.txt"), object_pairs_hook=OrderedDict) + affinity = pickle.load(open(fpath + "Y","rb"), encoding='latin1') + drugs = [] + prots = [] + for d in ligands.keys(): + lg = Chem.MolToSmiles(Chem.MolFromSmiles(ligands[d]),isomericSmiles=True) + drugs.append(lg) + for t in proteins.keys(): + prots.append(proteins[t]) + if dataset == 'davis': + affinity = [-np.log10(y/1e9) for y in affinity] + affinity = np.asarray(affinity) + opts = ['train','test'] + for opt in opts: + rows, cols = np.where(np.isnan(affinity)==False) + if opt=='train': + rows,cols = rows[train_fold], cols[train_fold] + elif opt=='test': + rows,cols = rows[valid_fold], cols[valid_fold] + with open('data/' + dataset + '_' + opt + '.csv', 'w') as f: + f.write('compound_iso_smiles,target_sequence,affinity\n') + for pair_ind in range(len(rows)): + ls = [] + ls += [ drugs[rows[pair_ind]] ] + ls += [ prots[cols[pair_ind]] ] + ls += [ affinity[rows[pair_ind],cols[pair_ind]] ] + f.write(','.join(map(str,ls)) + '\n') + print('\ndataset:', dataset) + print('train_fold:', len(train_fold)) + print('test_fold:', len(valid_fold)) + print('len(set(drugs)),len(set(prots)):', len(set(drugs)),len(set(prots))) + all_prots += list(set(prots)) + + +seq_voc = "ABCDEFGHIKLMNOPQRSTUVWXYZ" +seq_dict = {v:(i+1) for i,v in enumerate(seq_voc)} +seq_dict_len = len(seq_dict) +max_seq_len = 1000 + +compound_iso_smiles = [] +for dt_name in ['kiba','davis']: + opts = ['train','test'] + for opt in opts: + df = pd.read_csv('data/' + dt_name + '_' + opt + '.csv') + compound_iso_smiles += list( df['compound_iso_smiles'] ) +compound_iso_smiles = set(compound_iso_smiles) +smile_graph = {} +for smile in compound_iso_smiles: + g = smile_to_graph(smile) + smile_graph[smile] = g + +datasets = ['davis','kiba'] +# convert to PyTorch data format +for dataset in datasets: + processed_data_file_train = 'data/processed/' + dataset + '_train.pt' + processed_data_file_test = 'data/processed/' + dataset + '_test.pt' + if ((not os.path.isfile(processed_data_file_train)) or (not os.path.isfile(processed_data_file_test))): + df = pd.read_csv('data/' + dataset + '_train.csv') + train_drugs, train_prots, train_Y = list(df['compound_iso_smiles']),list(df['target_sequence']),list(df['affinity']) + XT = [seq_cat(t) for t in train_prots] + train_drugs, train_prots, train_Y = np.asarray(train_drugs), np.asarray(XT), np.asarray(train_Y) + df = pd.read_csv('data/' + dataset + '_test.csv') + test_drugs, test_prots, test_Y = list(df['compound_iso_smiles']),list(df['target_sequence']),list(df['affinity']) + XT = [seq_cat(t) for t in test_prots] + test_drugs, test_prots, test_Y = np.asarray(test_drugs), np.asarray(XT), np.asarray(test_Y) + + # make data PyTorch Geometric ready + print('preparing ', dataset + '_train.pt in pytorch format!') + train_data = TestbedDataset(root='data', dataset=dataset+'_train', xd=train_drugs, xt=train_prots, y=train_Y,smile_graph=smile_graph) + print('preparing ', dataset + '_test.pt in pytorch format!') + test_data = TestbedDataset(root='data', dataset=dataset+'_test', xd=test_drugs, xt=test_prots, y=test_Y,smile_graph=smile_graph) + print(processed_data_file_train, ' and ', processed_data_file_test, ' have been created') + else: + print(processed_data_file_train, ' and ', processed_data_file_test, ' are already created')