Diff of /dataset.py [000000] .. [607087]

Switch to unified view

a b/dataset.py
1
from itertools import accumulate
2
import numpy as np
3
import torch
4
from torch.utils.data import Dataset
5
6
7
class ProcessedLigandPocketDataset(Dataset):
    """Dataset of ligand/pocket samples read from a preprocessed ``.npz`` file.

    Every array in the archive except ``names`` and ``receptors`` is treated
    as the concatenation of all samples along the first axis. Arrays whose
    key contains ``'lig'`` are split wherever ``lig_mask`` changes value; all
    other arrays are split wherever ``pocket_mask`` changes value. After
    construction ``self.data[key][i]`` is the part of ``key`` that belongs to
    sample ``i``.
    """

    def __init__(self, npz_path, center=True, transform=None):
        """Load the archive, split it per sample and optionally center it.

        Args:
            npz_path: Path of the ``.npz`` archive passed to ``np.load``.
            center: If true, subtract from ``lig_coords`` and
                ``pocket_coords`` of every sample the mean over all ligand
                and pocket rows of that sample.
            transform: Optional callable applied to each sample dict in
                ``__getitem__``; its return value is what gets returned.
        """
        self.transform = transform

        # NOTE(review): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute arbitrary code. Only load trusted files.
        with np.load(npz_path, allow_pickle=True) as f:
            data = dict(f.items())

        # Split indices depend only on the mask, not on the array being
        # split, so compute them at most once per mask. They are looked up
        # lazily so that an archive lacking one of the masks only fails if a
        # key actually needs that mask.
        split_points = {}

        def sections_for(mask_key):
            if mask_key not in split_points:
                split_points[mask_key] = \
                    np.where(np.diff(data[mask_key]))[0] + 1
            return split_points[mask_key]

        # split data based on mask
        self.data = {}
        for k, v in data.items():
            if k in ('names', 'receptors'):
                # per-sample metadata is stored unsplit
                self.data[k] = v
                continue

            mask_key = 'lig_mask' if 'lig' in k else 'pocket_mask'
            self.data[k] = [torch.from_numpy(x)
                            for x in np.split(v, sections_for(mask_key))]

            # add number of nodes for convenience
            if k == 'lig_mask':
                self.data['num_lig_atoms'] = \
                    torch.tensor([len(x) for x in self.data['lig_mask']])
            elif k == 'pocket_mask':
                self.data['num_pocket_nodes'] = \
                    torch.tensor([len(x) for x in self.data['pocket_mask']])

        if center:
            lig_coords = self.data['lig_coords']
            pocket_coords = self.data['pocket_coords']
            for i, lig in enumerate(lig_coords):
                pocket = pocket_coords[i]
                # joint mean over ligand and pocket rows of this sample
                mean = (lig.sum(0) + pocket.sum(0)) / (len(lig) + len(pocket))
                lig_coords[i] = lig - mean
                pocket_coords[i] = pocket - mean

    def __len__(self):
        """Return the number of samples, i.e. the length of ``names``."""
        return len(self.data['names'])

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict, after the optional transform."""
        data = {key: val[idx] for key, val in self.data.items()}
        if self.transform is not None:
            data = self.transform(data)
        return data

    @staticmethod
    def collate_fn(batch):
        """Merge a list of sample dicts into one batch dict.

        The keys of the first sample decide which properties are collated:

        * ``names`` / ``receptors``: plain Python list, one entry per sample.
        * ``num_lig_atoms`` / ``num_pocket_nodes`` / ``num_virtual_atoms``:
          one tensor with one entry per sample.
        * any key containing ``'mask'``: the stored values are discarded and
          replaced by the position of the sample inside the batch, repeated
          once per row. The result is a floating-point tensor because it is
          built from ``torch.ones``.
        * everything else: concatenated along the first dimension.

        Args:
            batch: Non-empty list of dicts as returned by ``__getitem__``.

        Returns:
            Dict with the same keys as ``batch[0]``.
        """
        out = {}
        for prop in batch[0]:

            if prop in ('names', 'receptors'):
                out[prop] = [x[prop] for x in batch]
            elif prop in ('num_lig_atoms', 'num_pocket_nodes',
                          'num_virtual_atoms'):
                out[prop] = torch.tensor([x[prop] for x in batch])
            elif 'mask' in prop:
                # make sure indices in batch start at zero (needed for
                # torch_scatter)
                out[prop] = torch.cat([i * torch.ones(len(x[prop]))
                                       for i, x in enumerate(batch)], dim=0)
            else:
                out[prop] = torch.cat([x[prop] for x in batch], dim=0)

        return out