|
a |
|
b/utils.py |
|
|
1 |
import os |
|
|
2 |
import numpy as np |
|
|
3 |
from math import sqrt |
|
|
4 |
from scipy import stats |
|
|
5 |
from torch_geometric.data import InMemoryDataset, DataLoader |
|
|
6 |
from torch_geometric import data as DATA |
|
|
7 |
import torch |
|
|
8 |
|
|
|
9 |
class TestbedDataset(InMemoryDataset): |
|
|
10 |
def __init__(self, root='/tmp', dataset='davis', |
|
|
11 |
xd=None, xt=None, y=None, transform=None, |
|
|
12 |
pre_transform=None,smile_graph=None): |
|
|
13 |
|
|
|
14 |
#root is required for save preprocessed data, default is '/tmp' |
|
|
15 |
super(TestbedDataset, self).__init__(root, transform, pre_transform) |
|
|
16 |
# benchmark dataset, default = 'davis' |
|
|
17 |
self.dataset = dataset |
|
|
18 |
if os.path.isfile(self.processed_paths[0]): |
|
|
19 |
print('Pre-processed data found: {}, loading ...'.format(self.processed_paths[0])) |
|
|
20 |
self.data, self.slices = torch.load(self.processed_paths[0]) |
|
|
21 |
else: |
|
|
22 |
print('Pre-processed data {} not found, doing pre-processing...'.format(self.processed_paths[0])) |
|
|
23 |
self.process(xd, xt, y,smile_graph) |
|
|
24 |
self.data, self.slices = torch.load(self.processed_paths[0]) |
|
|
25 |
|
|
|
26 |
@property |
|
|
27 |
def raw_file_names(self): |
|
|
28 |
pass |
|
|
29 |
#return ['some_file_1', 'some_file_2', ...] |
|
|
30 |
|
|
|
31 |
@property |
|
|
32 |
def processed_file_names(self): |
|
|
33 |
return [self.dataset + '.pt'] |
|
|
34 |
|
|
|
35 |
def download(self): |
|
|
36 |
# Download to `self.raw_dir`. |
|
|
37 |
pass |
|
|
38 |
|
|
|
39 |
def _download(self): |
|
|
40 |
pass |
|
|
41 |
|
|
|
42 |
def _process(self): |
|
|
43 |
if not os.path.exists(self.processed_dir): |
|
|
44 |
os.makedirs(self.processed_dir) |
|
|
45 |
|
|
|
46 |
# Customize the process method to fit the task of drug-target affinity prediction |
|
|
47 |
# Inputs: |
|
|
48 |
# XD - list of SMILES, XT: list of encoded target (categorical or one-hot), |
|
|
49 |
# Y: list of labels (i.e. affinity) |
|
|
50 |
# Return: PyTorch-Geometric format processed data |
|
|
51 |
def process(self, xd, xt, y,smile_graph): |
|
|
52 |
assert (len(xd) == len(xt) and len(xt) == len(y)), "The three lists must be the same length!" |
|
|
53 |
data_list = [] |
|
|
54 |
data_len = len(xd) |
|
|
55 |
for i in range(data_len): |
|
|
56 |
print('Converting SMILES to graph: {}/{}'.format(i+1, data_len)) |
|
|
57 |
smiles = xd[i] |
|
|
58 |
target = xt[i] |
|
|
59 |
labels = y[i] |
|
|
60 |
# convert SMILES to molecular representation using rdkit |
|
|
61 |
c_size, features, edge_index = smile_graph[smiles] |
|
|
62 |
# make the graph ready for PyTorch Geometrics GCN algorithms: |
|
|
63 |
GCNData = DATA.Data(x=torch.Tensor(features), |
|
|
64 |
edge_index=torch.LongTensor(edge_index).transpose(1, 0), |
|
|
65 |
y=torch.FloatTensor([labels])) |
|
|
66 |
GCNData.target = torch.LongTensor([target]) |
|
|
67 |
GCNData.__setitem__('c_size', torch.LongTensor([c_size])) |
|
|
68 |
# append graph, label and target sequence to data list |
|
|
69 |
data_list.append(GCNData) |
|
|
70 |
|
|
|
71 |
if self.pre_filter is not None: |
|
|
72 |
data_list = [data for data in data_list if self.pre_filter(data)] |
|
|
73 |
|
|
|
74 |
if self.pre_transform is not None: |
|
|
75 |
data_list = [self.pre_transform(data) for data in data_list] |
|
|
76 |
print('Graph construction done. Saving to file.') |
|
|
77 |
data, slices = self.collate(data_list) |
|
|
78 |
# save preprocessed data: |
|
|
79 |
torch.save((data, slices), self.processed_paths[0]) |
|
|
80 |
|
|
|
81 |
def rmse(y,f): |
|
|
82 |
rmse = sqrt(((y - f)**2).mean(axis=0)) |
|
|
83 |
return rmse |
|
|
84 |
def mse(y,f): |
|
|
85 |
mse = ((y - f)**2).mean(axis=0) |
|
|
86 |
return mse |
|
|
87 |
def pearson(y,f): |
|
|
88 |
rp = np.corrcoef(y, f)[0,1] |
|
|
89 |
return rp |
|
|
90 |
def spearman(y,f): |
|
|
91 |
rs = stats.spearmanr(y, f)[0] |
|
|
92 |
return rs |
|
|
93 |
def ci(y,f): |
|
|
94 |
ind = np.argsort(y) |
|
|
95 |
y = y[ind] |
|
|
96 |
f = f[ind] |
|
|
97 |
i = len(y)-1 |
|
|
98 |
j = i-1 |
|
|
99 |
z = 0.0 |
|
|
100 |
S = 0.0 |
|
|
101 |
while i > 0: |
|
|
102 |
while j >= 0: |
|
|
103 |
if y[i] > y[j]: |
|
|
104 |
z = z+1 |
|
|
105 |
u = f[i] - f[j] |
|
|
106 |
if u > 0: |
|
|
107 |
S = S + 1 |
|
|
108 |
elif u == 0: |
|
|
109 |
S = S + 0.5 |
|
|
110 |
j = j - 1 |
|
|
111 |
i = i - 1 |
|
|
112 |
j = i-1 |
|
|
113 |
ci = S/z |
|
|
114 |
return ci |