|
a |
|
b/create_data.py |
|
|
1 |
import pandas as pd |
|
|
2 |
import numpy as np |
|
|
3 |
import os |
|
|
4 |
import json,pickle |
|
|
5 |
from collections import OrderedDict |
|
|
6 |
from rdkit import Chem |
|
|
7 |
from rdkit.Chem import MolFromSmiles |
|
|
8 |
import networkx as nx |
|
|
9 |
from utils import * |
|
|
10 |
|
|
|
11 |
def atom_features(atom):
    """Build the feature vector for one RDKit atom.

    The vector is the concatenation of:

    * a 44-way one-hot of the element symbol (symbols outside the list map
      to the final ``'Unknown'`` slot),
    * an 11-way one-hot of ``atom.GetDegree()`` (0..10),
    * an 11-way one-hot of ``atom.GetTotalNumHs()`` (0..10),
    * an 11-way one-hot of ``atom.GetImplicitValence()`` (0..10),
    * one aromaticity flag,

    for 78 entries in total.

    Args:
        atom: an RDKit atom exposing ``GetSymbol``, ``GetDegree``,
            ``GetTotalNumHs``, ``GetImplicitValence`` and ``GetIsAromatic``.

    Returns:
        A NumPy array with 78 entries.

    Raises:
        Exception: from ``one_of_k_encoding`` when the degree is not in 0..10.
    """
    symbols = [
        'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
        'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',
        'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H',
        'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
        'Cr', 'Pt', 'Hg', 'Pb', 'Unknown',
    ]
    zero_to_ten = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    return np.array(
        one_of_k_encoding_unk(atom.GetSymbol(), symbols)
        # Degree uses the strict encoder: an out-of-range degree is an error.
        + one_of_k_encoding(atom.GetDegree(), zero_to_ten)
        # The next two use the lenient encoder: out-of-range values land in
        # the last bucket (10).
        + one_of_k_encoding_unk(atom.GetTotalNumHs(), zero_to_ten)
        + one_of_k_encoding_unk(atom.GetImplicitValence(), zero_to_ten)
        + [atom.GetIsAromatic()]
    )
|
|
17 |
|
|
|
18 |
def one_of_k_encoding(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set``.

    Args:
        x: the value to encode.
        allowable_set: sequence of permitted values; defines the output
            length and slot order.

    Returns:
        A list with one boolean per element of ``allowable_set``; an entry is
        true where that element equals ``x``.

    Raises:
        ValueError: if ``x`` is not in ``allowable_set``. ``ValueError`` is a
            subclass of ``Exception``, so ``except Exception`` handlers still
            catch it.
    """
    if x not in allowable_set:
        raise ValueError("input {0} not in allowable set{1}:".format(x, allowable_set))
    return [x == s for s in allowable_set]
|
|
22 |
|
|
|
23 |
def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x``, mapping unknown values to the last element.

    Args:
        x: the value to encode.
        allowable_set: non-empty sequence of permitted values; its final
            element doubles as the catch-all bucket.

    Returns:
        A list with one boolean per element of ``allowable_set``; an entry is
        true where that element equals ``x`` (or equals the last element when
        ``x`` is not in the set).
    """
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]
|
|
28 |
|
|
|
29 |
def smile_to_graph(smile):
    """Convert a SMILES string into a simple graph description.

    Args:
        smile: SMILES string of the molecule.

    Returns:
        A tuple ``(c_size, features, edge_index)`` where ``c_size`` is the
        number of atoms, ``features`` is a list with one normalised
        ``atom_features`` vector per atom (each divided by its own sum), and
        ``edge_index`` is a list of ``[source, target]`` atom-index pairs
        containing every bond in both directions.

    Raises:
        ValueError: if RDKit cannot parse ``smile``.
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        # Fail here with the offending input instead of an AttributeError on
        # None a line later.
        raise ValueError("could not parse SMILES: {0!r}".format(smile))

    c_size = mol.GetNumAtoms()

    features = []
    for atom in mol.GetAtoms():
        feature = atom_features(atom)
        features.append(feature / sum(feature))

    edges = [
        [bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]
        for bond in mol.GetBonds()
    ]
    # Going through an undirected graph and then to_directed() yields each
    # bond once per direction.
    g = nx.Graph(edges).to_directed()
    edge_index = [[e1, e2] for e1, e2 in g.edges]

    return c_size, features, edge_index
|
|
48 |
|
|
|
49 |
def seq_cat(prot):
    """Encode a protein sequence as a fixed-length array of integer codes.

    Uses the module-level ``seq_dict`` and ``max_seq_len``. Sequences longer
    than ``max_seq_len`` are truncated; shorter ones are right-padded with
    zeros.

    Args:
        prot: the protein sequence as a string of residue letters.

    Returns:
        A NumPy array of length ``max_seq_len``.

    Raises:
        KeyError: if the sequence contains a character missing from
            ``seq_dict``.
    """
    x = np.zeros(max_seq_len)
    for i, ch in enumerate(prot[:max_seq_len]):
        x[i] = seq_dict[ch]
    return x
|
|
54 |
|
|
|
55 |
|
|
|
56 |
# from DeepDTA data
# For each dataset: read the fold definitions, ligands, proteins and the
# affinity matrix, then write data/<dataset>_train.csv and
# data/<dataset>_test.csv with one (SMILES, sequence, affinity) row per pair.
all_prots = []
datasets = ['kiba', 'davis']
for dataset in datasets:
    print('convert data from DeepDTA for ', dataset)
    fpath = 'data/' + dataset + '/'

    # Context managers so each input file is closed as soon as it is parsed.
    with open(fpath + "folds/train_fold_setting1.txt") as fh:
        train_fold = json.load(fh)
    # The training file is a list of folds; flatten it into one index list.
    train_fold = [ee for e in train_fold for ee in e]
    with open(fpath + "folds/test_fold_setting1.txt") as fh:
        valid_fold = json.load(fh)
    # OrderedDict keeps file order, so list positions below line up with the
    # row/column indices of the affinity matrix.
    with open(fpath + "ligands_can.txt") as fh:
        ligands = json.load(fh, object_pairs_hook=OrderedDict)
    with open(fpath + "proteins.txt") as fh:
        proteins = json.load(fh, object_pairs_hook=OrderedDict)
    # NOTE(review): pickle.load can execute arbitrary code; only run this on
    # trusted dataset files.
    with open(fpath + "Y", "rb") as fh:
        affinity = pickle.load(fh, encoding='latin1')

    # Round-trip every ligand through RDKit to get its isomeric SMILES.
    drugs = [
        Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True)
        for smiles in ligands.values()
    ]
    prots = list(proteins.values())

    if dataset == 'davis':
        affinity = [-np.log10(y / 1e9) for y in affinity]
    affinity = np.asarray(affinity)

    # Positions of the non-NaN entries; the fold files index into this
    # flattened list. Computed once because it is the same for both splits.
    known_rows, known_cols = np.where(~np.isnan(affinity))

    opts = ['train', 'test']
    for opt in opts:
        if opt == 'train':
            rows, cols = known_rows[train_fold], known_cols[train_fold]
        elif opt == 'test':
            rows, cols = known_rows[valid_fold], known_cols[valid_fold]
        with open('data/' + dataset + '_' + opt + '.csv', 'w') as f:
            f.write('compound_iso_smiles,target_sequence,affinity\n')
            for row, col in zip(rows, cols):
                ls = [drugs[row], prots[col], affinity[row, col]]
                f.write(','.join(map(str, ls)) + '\n')

    print('\ndataset:', dataset)
    print('train_fold:', len(train_fold))
    print('test_fold:', len(valid_fold))
    print('len(set(drugs)),len(set(prots)):', len(set(drugs)), len(set(prots)))
    all_prots += list(set(prots))
|
|
98 |
|
|
|
99 |
|
|
|
100 |
# Residue vocabulary used by seq_cat: 25 upper-case letters (no 'J'). Codes
# start at 1, so 0 is left for the zero padding seq_cat starts from.
seq_voc = "ABCDEFGHIKLMNOPQRSTUVWXYZ"
seq_dict = {v: (i + 1) for i, v in enumerate(seq_voc)}
seq_dict_len = len(seq_dict)
max_seq_len = 1000

# Collect every SMILES string that appears in any generated CSV split.
compound_iso_smiles = []
for dt_name in ['kiba', 'davis']:
    opts = ['train', 'test']
    for opt in opts:
        df = pd.read_csv('data/' + dt_name + '_' + opt + '.csv')
        compound_iso_smiles += list(df['compound_iso_smiles'])
# De-duplicate so each molecule is converted to a graph only once.
compound_iso_smiles = set(compound_iso_smiles)
smile_graph = {smile: smile_to_graph(smile) for smile in compound_iso_smiles}
|
|
116 |
|
|
|
117 |
def _load_split(csv_path):
    """Read one split CSV and return (drugs, encoded proteins, affinities) as arrays."""
    df = pd.read_csv(csv_path)
    drugs = np.asarray(list(df['compound_iso_smiles']))
    prots = np.asarray([seq_cat(t) for t in list(df['target_sequence'])])
    labels = np.asarray(list(df['affinity']))
    return drugs, prots, labels


datasets = ['davis', 'kiba']
# convert to PyTorch data format
for dataset in datasets:
    processed_data_file_train = 'data/processed/' + dataset + '_train.pt'
    processed_data_file_test = 'data/processed/' + dataset + '_test.pt'
    # Rebuild both files unless both already exist.
    if ((not os.path.isfile(processed_data_file_train)) or (not os.path.isfile(processed_data_file_test))):
        train_drugs, train_prots, train_Y = _load_split('data/' + dataset + '_train.csv')
        test_drugs, test_prots, test_Y = _load_split('data/' + dataset + '_test.csv')

        # make data PyTorch Geometric ready
        print('preparing ', dataset + '_train.pt in pytorch format!')
        train_data = TestbedDataset(root='data', dataset=dataset+'_train', xd=train_drugs, xt=train_prots, y=train_Y,smile_graph=smile_graph)
        print('preparing ', dataset + '_test.pt in pytorch format!')
        test_data = TestbedDataset(root='data', dataset=dataset+'_test', xd=test_drugs, xt=test_prots, y=test_Y,smile_graph=smile_graph)
        print(processed_data_file_train, ' and ', processed_data_file_test, ' have been created')
    else:
        print(processed_data_file_train, ' and ', processed_data_file_test, ' are already created')