Diff of /create_data.py [000000] .. [64be90]

Switch to unified view

a b/create_data.py
1
import pandas as pd
2
import numpy as np
3
import os
4
import json,pickle
5
from collections import OrderedDict
6
from rdkit import Chem
7
from rdkit.Chem import MolFromSmiles
8
import networkx as nx
9
from utils import *
10
11
def atom_features(atom):
    """Build the feature vector describing a single atom.

    The vector is the concatenation, in this order, of:

    * a one-hot of the element symbol (44 slots; symbols outside the list
      are mapped onto the trailing ``'Unknown'`` slot),
    * a one-hot of ``atom.GetDegree()`` over 0..10 (a value outside that
      range raises, because the strict ``one_of_k_encoding`` is used),
    * a one-hot of ``atom.GetTotalNumHs()`` over 0..10,
    * a one-hot of ``atom.GetImplicitValence()`` over 0..10
      (for both of these, out-of-range values fall onto the last slot),
    * a single aromaticity flag from ``atom.GetIsAromatic()``.

    Args:
        atom: object exposing ``GetSymbol``, ``GetDegree``,
            ``GetTotalNumHs``, ``GetImplicitValence`` and ``GetIsAromatic``
            (presumably an RDKit atom -- see the caller).

    Returns:
        A 1-D NumPy array with 44 + 11 + 11 + 11 + 1 = 78 entries.
    """
    known_symbols = [
        'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
        'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',
        'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H',
        'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
        'Cr', 'Pt', 'Hg', 'Pb', 'Unknown',
    ]
    zero_to_ten = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    # Each piece is computed in the same order as the final concatenation.
    symbol_part = one_of_k_encoding_unk(atom.GetSymbol(), known_symbols)
    degree_part = one_of_k_encoding(atom.GetDegree(), zero_to_ten)
    hydrogen_part = one_of_k_encoding_unk(atom.GetTotalNumHs(), zero_to_ten)
    valence_part = one_of_k_encoding_unk(atom.GetImplicitValence(), zero_to_ten)
    aromatic_part = [atom.GetIsAromatic()]

    return np.array(
        symbol_part + degree_part + hydrogen_part + valence_part + aromatic_part
    )
17
18
def one_of_k_encoding(x, allowable_set):
    """Return a strict one-hot encoding of ``x`` over ``allowable_set``.

    Args:
        x: value to encode.
        allowable_set: sequence of admissible values.

    Returns:
        A list of booleans, one per element of ``allowable_set``, that is
        ``True`` wherever the element equals ``x``.

    Raises:
        Exception: if ``x`` is not a member of ``allowable_set``.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
22
23
def one_of_k_encoding_unk(x, allowable_set):
    """Return a one-hot encoding of ``x``, with a catch-all last slot.

    A value that is not in ``allowable_set`` is treated as if it were the
    set's last element, so the result always marks at least one position.

    Args:
        x: value to encode.
        allowable_set: non-empty sequence of admissible values; its last
            element doubles as the "unknown" bucket.

    Returns:
        A list of booleans, one per element of ``allowable_set``.
    """
    effective = x if x in allowable_set else allowable_set[-1]
    return [effective == candidate for candidate in allowable_set]
28
29
def smile_to_graph(smile):
    """Convert a SMILES string into the pieces of a molecular graph.

    Args:
        smile: SMILES string describing the molecule.

    Returns:
        A tuple ``(c_size, features, edge_index)`` where

        * ``c_size`` is the number of atoms in the molecule,
        * ``features`` is a list with one array per atom: the output of
          ``atom_features`` divided by the sum of its entries,
        * ``edge_index`` is a list of ``[source, target]`` atom-index pairs
          containing every bond in both directions.

    Raises:
        ValueError: if RDKit cannot parse ``smile``.
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with the offending string instead of an opaque AttributeError below.
        raise ValueError("could not parse SMILES string: {!r}".format(smile))

    c_size = mol.GetNumAtoms()

    features = []
    for atom in mol.GetAtoms():
        feature = atom_features(atom)
        # Scale each atom's vector so that its entries sum to one.
        features.append(feature / sum(feature))

    edges = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]
             for bond in mol.GetBonds()]
    # Going through an undirected graph and back to a directed one yields
    # each bond as two opposite directed edges.
    g = nx.Graph(edges).to_directed()
    edge_index = [[e1, e2] for e1, e2 in g.edges]

    return c_size, features, edge_index
48
49
def seq_cat(prot):
    """Encode a protein sequence as a fixed-length array of integer codes.

    Uses the module-level ``max_seq_len`` (output length) and ``seq_dict``
    (character -> code table). Sequences longer than ``max_seq_len`` are
    truncated; shorter ones leave the remaining positions at zero.

    Args:
        prot: the protein sequence as a string.

    Returns:
        A NumPy array of length ``max_seq_len``.

    Raises:
        KeyError: if the sequence contains a character missing from
            ``seq_dict``.
    """
    encoded = np.zeros(max_seq_len)
    codes = [seq_dict[residue] for residue in prot[:max_seq_len]]
    # Fill the leading positions; the tail keeps its zero padding.
    encoded[:len(codes)] = codes
    return encoded
54
55
56
# from DeepDTA data
# For every benchmark dataset, read the raw DeepDTA files and write one CSV
# per split (train / test) with columns
# compound_iso_smiles, target_sequence, affinity.
all_prots = []
datasets = ['kiba','davis']
for dataset in datasets:
    print('convert data from DeepDTA for ', dataset)
    fpath = 'data/' + dataset + '/'

    # Every input is opened through a context manager so that its handle is
    # closed as soon as it has been parsed.
    with open(fpath + "folds/train_fold_setting1.txt") as f:
        train_fold = json.load(f)
    # The training file is a nested list of index lists; flatten it.
    train_fold = [ee for e in train_fold for ee in e]
    with open(fpath + "folds/test_fold_setting1.txt") as f:
        valid_fold = json.load(f)
    with open(fpath + "ligands_can.txt") as f:
        ligands = json.load(f, object_pairs_hook=OrderedDict)
    with open(fpath + "proteins.txt") as f:
        proteins = json.load(f, object_pairs_hook=OrderedDict)
    with open(fpath + "Y", "rb") as f:
        # NOTE(review): unpickling can execute arbitrary code; only run this
        # on affinity files obtained from a trusted source.
        affinity = pickle.load(f, encoding='latin1')

    # Re-serialise every ligand through RDKit so that all SMILES strings
    # share one isomeric form; file order is preserved.
    drugs = [
        Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True)
        for smiles in ligands.values()
    ]
    prots = list(proteins.values())

    if dataset == 'davis':
        # Davis values are rescaled as -log10(y / 1e9).
        affinity = [-np.log10(y/1e9) for y in affinity]
    affinity = np.asarray(affinity)

    # Coordinates of all measured (non-NaN) drug/protein pairs. They do not
    # depend on the split, so compute them once; the fold files index into
    # this list of pairs.
    measured_rows, measured_cols = np.where(~np.isnan(affinity))
    fold_for_opt = {'train': train_fold, 'test': valid_fold}

    opts = ['train','test']
    for opt in opts:
        fold = fold_for_opt[opt]
        rows, cols = measured_rows[fold], measured_cols[fold]
        with open('data/' + dataset + '_' + opt + '.csv', 'w') as f:
            f.write('compound_iso_smiles,target_sequence,affinity\n')
            for row, col in zip(rows, cols):
                ls = [drugs[row], prots[col], affinity[row, col]]
                f.write(','.join(map(str, ls)) + '\n')

    print('\ndataset:', dataset)
    print('train_fold:', len(train_fold))
    print('test_fold:', len(valid_fold))
    print('len(set(drugs)),len(set(prots)):', len(set(drugs)),len(set(prots)))
    all_prots += list(set(prots))
98
    
99
    
100
# Protein-sequence vocabulary: every letter gets a 1-based integer code in
# the order it appears in seq_voc, which leaves 0 unused by the table
# (seq_cat starts from an all-zero array, so 0 ends up as padding).
seq_voc = "ABCDEFGHIKLMNOPQRSTUVWXYZ"
seq_dict = dict(zip(seq_voc, range(1, len(seq_voc) + 1)))
seq_dict_len = len(seq_dict)
# Fixed length of every encoded protein sequence.
max_seq_len = 1000
104
105
# Gather every SMILES string that appears in any of the generated CSV files,
# then build its graph once so that all datasets can share the result.
compound_iso_smiles = []
for dt_name in ['kiba','davis']:
    opts = ['train','test']
    for opt in opts:
        df = pd.read_csv('data/' + dt_name + '_' + opt + '.csv')
        compound_iso_smiles.extend(df['compound_iso_smiles'])
# Deduplicate: a compound is converted only once however often it occurs.
compound_iso_smiles = set(compound_iso_smiles)
# Map each SMILES string to its (c_size, features, edge_index) tuple.
smile_graph = {smile: smile_to_graph(smile) for smile in compound_iso_smiles}
116
117
datasets = ['davis','kiba']
# convert to PyTorch data format
for dataset in datasets:
    processed_data_file_train = 'data/processed/' + dataset + '_train.pt'
    processed_data_file_test = 'data/processed/' + dataset + '_test.pt'

    # Nothing to do when both processed files are already on disk.
    if os.path.isfile(processed_data_file_train) and os.path.isfile(processed_data_file_test):
        print(processed_data_file_train, ' and ', processed_data_file_test, ' are already created')
        continue

    # Training split: SMILES strings, encoded protein sequences, affinities.
    df = pd.read_csv('data/' + dataset + '_train.csv')
    XT = [seq_cat(t) for t in list(df['target_sequence'])]
    train_drugs = np.asarray(list(df['compound_iso_smiles']))
    train_prots = np.asarray(XT)
    train_Y = np.asarray(list(df['affinity']))

    # Test split, processed the same way.
    df = pd.read_csv('data/' + dataset + '_test.csv')
    XT = [seq_cat(t) for t in list(df['target_sequence'])]
    test_drugs = np.asarray(list(df['compound_iso_smiles']))
    test_prots = np.asarray(XT)
    test_Y = np.asarray(list(df['affinity']))

    # make data PyTorch Geometric ready
    print('preparing ', dataset + '_train.pt in pytorch format!')
    train_data = TestbedDataset(
        root='data',
        dataset=dataset+'_train',
        xd=train_drugs,
        xt=train_prots,
        y=train_Y,
        smile_graph=smile_graph,
    )
    print('preparing ', dataset + '_test.pt in pytorch format!')
    test_data = TestbedDataset(
        root='data',
        dataset=dataset+'_test',
        xd=test_drugs,
        xt=test_prots,
        y=test_Y,
        smile_graph=smile_graph,
    )
    print(processed_data_file_train, ' and ', processed_data_file_test, ' have been created')