--- a +++ b/unimol/data/distance_dataset.py @@ -0,0 +1,64 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from scipy.spatial import distance_matrix +from functools import lru_cache +from unicore.data import BaseWrapperDataset + + +class DistanceDataset(BaseWrapperDataset): + def __init__(self, dataset): + super().__init__(dataset) + self.dataset = dataset + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + pos = self.dataset[idx].view(-1, 3).numpy() + dist = distance_matrix(pos, pos).astype(np.float32) + return torch.from_numpy(dist) + + +class EdgeTypeDataset(BaseWrapperDataset): + def __init__(self, dataset: torch.utils.data.Dataset, num_types: int): + self.dataset = dataset + self.num_types = num_types + + @lru_cache(maxsize=16) + def __getitem__(self, index: int): + node_input = self.dataset[index].clone() + offset = node_input.view(-1, 1) * self.num_types + node_input.view(1, -1) + return offset + + +class CrossDistanceDataset(BaseWrapperDataset): + def __init__(self, mol_dataset, pocket_dataset): + super().__init__(mol_dataset) + self.dataset = mol_dataset + self.mol_dataset = mol_dataset + self.pocket_dataset = pocket_dataset + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + mol_pos = self.mol_dataset[idx].view(-1, 3).numpy() + pocket_pos = self.pocket_dataset[idx].view(-1, 3).numpy() + dist = distance_matrix(mol_pos, pocket_pos).astype(np.float32) + assert dist.shape[0] == self.mol_dataset[idx].shape[0] + assert dist.shape[1] == self.pocket_dataset[idx].shape[0] + return torch.from_numpy(dist) + +class CrossEdgeTypeDataset(BaseWrapperDataset): + def __init__(self, mol_dataset, pocket_dataset, num_types: int): + self.dataset = mol_dataset + self.mol_dataset = mol_dataset + self.pocket_dataset = pocket_dataset + self.num_types = num_types + + @lru_cache(maxsize=16) + def __getitem__(self, index: int): + mol_node_input = self.mol_dataset[index].clone() + pocket_node_input = self.pocket_dataset[index].clone() + offset = mol_node_input.view(-1, 1) * self.num_types + pocket_node_input.view(1, -1) + return offset \ No newline at end of file