--- a +++ b/dataset.py @@ -0,0 +1,70 @@ +from itertools import accumulate +import numpy as np +import torch +from torch.utils.data import Dataset + + +class ProcessedLigandPocketDataset(Dataset): + def __init__(self, npz_path, center=True, transform=None): + + self.transform = transform + + with np.load(npz_path, allow_pickle=True) as f: + data = {key: val for key, val in f.items()} + + # split data based on mask + self.data = {} + for (k, v) in data.items(): + if k == 'names' or k == 'receptors': + self.data[k] = v + continue + + sections = np.where(np.diff(data['lig_mask']))[0] + 1 \ + if 'lig' in k \ + else np.where(np.diff(data['pocket_mask']))[0] + 1 + self.data[k] = [torch.from_numpy(x) for x in np.split(v, sections)] + + # add number of nodes for convenience + if k == 'lig_mask': + self.data['num_lig_atoms'] = \ + torch.tensor([len(x) for x in self.data['lig_mask']]) + elif k == 'pocket_mask': + self.data['num_pocket_nodes'] = \ + torch.tensor([len(x) for x in self.data['pocket_mask']]) + + if center: + for i in range(len(self.data['lig_coords'])): + mean = (self.data['lig_coords'][i].sum(0) + + self.data['pocket_coords'][i].sum(0)) / \ + (len(self.data['lig_coords'][i]) + len(self.data['pocket_coords'][i])) + self.data['lig_coords'][i] = self.data['lig_coords'][i] - mean + self.data['pocket_coords'][i] = self.data['pocket_coords'][i] - mean + + def __len__(self): + return len(self.data['names']) + + def __getitem__(self, idx): + data = {key: val[idx] for key, val in self.data.items()} + if self.transform is not None: + data = self.transform(data) + return data + + @staticmethod + def collate_fn(batch): + out = {} + for prop in batch[0].keys(): + + if prop == 'names' or prop == 'receptors': + out[prop] = [x[prop] for x in batch] + elif prop == 'num_lig_atoms' or prop == 'num_pocket_nodes' \ + or prop == 'num_virtual_atoms': + out[prop] = torch.tensor([x[prop] for x in batch]) + elif 'mask' in prop: + # make sure indices in batch start at zero (needed for + # torch_scatter) + out[prop] = torch.cat([i * torch.ones(len(x[prop])) + for i, x in enumerate(batch)], dim=0) + else: + out[prop] = torch.cat([x[prop] for x in batch], dim=0) + + return out