Diff of /utils.py [000000] .. [64be90]

Switch to unified view

a b/utils.py
1
import os
2
import numpy as np
3
from math import sqrt
4
from scipy import stats
5
from torch_geometric.data import InMemoryDataset, DataLoader
6
from torch_geometric import data as DATA
7
import torch
8
9
class TestbedDataset(InMemoryDataset):
10
    def __init__(self, root='/tmp', dataset='davis', 
11
                 xd=None, xt=None, y=None, transform=None,
12
                 pre_transform=None,smile_graph=None):
13
14
        #root is required for save preprocessed data, default is '/tmp'
15
        super(TestbedDataset, self).__init__(root, transform, pre_transform)
16
        # benchmark dataset, default = 'davis'
17
        self.dataset = dataset
18
        if os.path.isfile(self.processed_paths[0]):
19
            print('Pre-processed data found: {}, loading ...'.format(self.processed_paths[0]))
20
            self.data, self.slices = torch.load(self.processed_paths[0])
21
        else:
22
            print('Pre-processed data {} not found, doing pre-processing...'.format(self.processed_paths[0]))
23
            self.process(xd, xt, y,smile_graph)
24
            self.data, self.slices = torch.load(self.processed_paths[0])
25
26
    @property
27
    def raw_file_names(self):
28
        pass
29
        #return ['some_file_1', 'some_file_2', ...]
30
31
    @property
32
    def processed_file_names(self):
33
        return [self.dataset + '.pt']
34
35
    def download(self):
36
        # Download to `self.raw_dir`.
37
        pass
38
39
    def _download(self):
40
        pass
41
42
    def _process(self):
43
        if not os.path.exists(self.processed_dir):
44
            os.makedirs(self.processed_dir)
45
46
    # Customize the process method to fit the task of drug-target affinity prediction
47
    # Inputs:
48
    # XD - list of SMILES, XT: list of encoded target (categorical or one-hot),
49
    # Y: list of labels (i.e. affinity)
50
    # Return: PyTorch-Geometric format processed data
51
    def process(self, xd, xt, y,smile_graph):
52
        assert (len(xd) == len(xt) and len(xt) == len(y)), "The three lists must be the same length!"
53
        data_list = []
54
        data_len = len(xd)
55
        for i in range(data_len):
56
            print('Converting SMILES to graph: {}/{}'.format(i+1, data_len))
57
            smiles = xd[i]
58
            target = xt[i]
59
            labels = y[i]
60
            # convert SMILES to molecular representation using rdkit
61
            c_size, features, edge_index = smile_graph[smiles]
62
            # make the graph ready for PyTorch Geometrics GCN algorithms:
63
            GCNData = DATA.Data(x=torch.Tensor(features),
64
                                edge_index=torch.LongTensor(edge_index).transpose(1, 0),
65
                                y=torch.FloatTensor([labels]))
66
            GCNData.target = torch.LongTensor([target])
67
            GCNData.__setitem__('c_size', torch.LongTensor([c_size]))
68
            # append graph, label and target sequence to data list
69
            data_list.append(GCNData)
70
71
        if self.pre_filter is not None:
72
            data_list = [data for data in data_list if self.pre_filter(data)]
73
74
        if self.pre_transform is not None:
75
            data_list = [self.pre_transform(data) for data in data_list]
76
        print('Graph construction done. Saving to file.')
77
        data, slices = self.collate(data_list)
78
        # save preprocessed data:
79
        torch.save((data, slices), self.processed_paths[0])
80
81
def rmse(y,f):
82
    rmse = sqrt(((y - f)**2).mean(axis=0))
83
    return rmse
84
def mse(y,f):
85
    mse = ((y - f)**2).mean(axis=0)
86
    return mse
87
def pearson(y,f):
88
    rp = np.corrcoef(y, f)[0,1]
89
    return rp
90
def spearman(y,f):
91
    rs = stats.spearmanr(y, f)[0]
92
    return rs
93
def ci(y,f):
94
    ind = np.argsort(y)
95
    y = y[ind]
96
    f = f[ind]
97
    i = len(y)-1
98
    j = i-1
99
    z = 0.0
100
    S = 0.0
101
    while i > 0:
102
        while j >= 0:
103
            if y[i] > y[j]:
104
                z = z+1
105
                u = f[i] - f[j]
106
                if u > 0:
107
                    S = S + 1
108
                elif u == 0:
109
                    S = S + 0.5
110
            j = j - 1
111
        i = i - 1
112
        j = i-1
113
    ci = S/z
114
    return ci