--- a +++ b/unimol/data/cropping_dataset.py @@ -0,0 +1,269 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from functools import lru_cache +import logging +from unicore.data import BaseWrapperDataset +from . import data_utils + +logger = logging.getLogger(__name__) + + +class CroppingDataset(BaseWrapperDataset): + def __init__(self, dataset, seed, atoms, coordinates, max_atoms=256): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.max_atoms = max_atoms + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + coordinates = dd[self.coordinates] + if self.max_atoms and len(atoms) > self.max_atoms: + with data_utils.numpy_seed(self.seed, epoch, index): + index = np.random.choice(len(atoms), self.max_atoms, replace=False) + atoms = np.array(atoms)[index] + coordinates = coordinates[index] + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class CroppingPocketDataset(BaseWrapperDataset): + def __init__(self, dataset, seed, atoms, coordinates, max_atoms=256): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.max_atoms = ( + max_atoms # max number of atoms in a molecule, None indicates no limit. + ) + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + coordinates = dd[self.coordinates] + #residue = dd["residue"] + + # crop atoms according to their distance to the center of pockets + if self.max_atoms and len(atoms) > self.max_atoms: + with data_utils.numpy_seed(self.seed, epoch, index): + distance = np.linalg.norm( + coordinates - coordinates.mean(axis=0), axis=1 + ) + + def softmax(x): + x -= np.max(x) + x = np.exp(x) / np.sum(np.exp(x)) + return x + + distance += 1 # prevent inf + weight = softmax(np.reciprocal(distance)) + index = np.random.choice( + len(atoms), self.max_atoms, replace=False, p=weight + ) + atoms = atoms[index] + coordinates = coordinates[index] + #residue = residue[index] + + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + #dd["residue"] = residue + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class CroppingResiduePocketDataset(BaseWrapperDataset): + def __init__(self, dataset, seed, atoms, residues, coordinates, max_atoms=256): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.residues = residues + self.coordinates = coordinates + self.max_atoms = ( + max_atoms # max number of atoms in a molecule, None indicates no limit. + ) + + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + residues = dd[self.residues] + coordinates = dd[self.coordinates] + + residues_distance_map = {} + + # crop atoms according to their distance to the center of pockets + if self.max_atoms and len(atoms) > self.max_atoms: + with data_utils.numpy_seed(self.seed, epoch, index): + distance = np.linalg.norm( + coordinates - coordinates.mean(axis=0), axis=1 + ) + residues_ids, residues_distance = [], [] + for res in residues: + if res not in residues_ids: + residues_ids.append(res) + residues_distance.append(distance[residues == res].mean()) + residues_ids = np.array(residues_ids) + residues_distance = np.array(residues_distance) + + def softmax(x): + x -= np.max(x) + x = np.exp(x) / np.sum(np.exp(x)) + return x + + residues_distance += 1 # prevent inf and smoothing out the distance + weight = softmax(np.reciprocal(residues_distance)) + max_residues = self.max_atoms // (len(atoms) // (len(residues_ids) + 1)) + if max_residues < 1: + max_residues += 1 + max_residues = min(max_residues, len(residues_ids)) + residue_index = np.random.choice( + len(residues_ids), max_residues, replace=False, p=weight + ) + index = [ + i + for i in range(len(atoms)) + if residues[i] in residues_ids[residue_index] + ] + atoms = atoms[index] + coordinates = coordinates[index] + residues = residues[index] + + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + dd[self.residues] = residues + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class CroppingPocketDockingPoseDataset(BaseWrapperDataset): + def __init__( + self, dataset, seed, atoms, coordinates, holo_coordinates, max_atoms=256 + ): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.max_atoms = max_atoms + + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + coordinates = dd[self.coordinates] + holo_coordinates = dd[self.holo_coordinates] + + # crop atoms according to their distance to the center of pockets + #print(len(atoms)) + if self.max_atoms and len(atoms) > self.max_atoms: + with data_utils.numpy_seed(self.seed, 1): + distance = np.linalg.norm( + coordinates - coordinates.mean(axis=0), axis=1 + ) + + def softmax(x): + x -= np.max(x) + x = np.exp(x) / np.sum(np.exp(x)) + return x + + distance += 1 # prevent inf + weight = softmax(np.reciprocal(distance)) + index = np.random.choice( + len(atoms), self.max_atoms, replace=False, p=weight + ) + atoms = atoms[index] + coordinates = coordinates[index] + holo_coordinates = holo_coordinates[index] + + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + dd[self.holo_coordinates] = holo_coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + +class CroppingPocketDockingPoseTestDataset(BaseWrapperDataset): + def __init__( + self, dataset, seed, atoms, coordinates, max_atoms=256 + ): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.max_atoms = max_atoms + + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + coordinates = dd[self.coordinates] + + # crop atoms according to their distance to the center of pockets + if self.max_atoms and len(atoms) > self.max_atoms: + with data_utils.numpy_seed(1, 1): + distance = np.linalg.norm( + coordinates - coordinates.mean(axis=0), axis=1 + ) + + def softmax(x): + x -= np.max(x) + x = np.exp(x) / np.sum(np.exp(x)) + return x + + distance += 1 # prevent inf + weight = softmax(np.reciprocal(distance)) + index = np.random.choice( + len(atoms), self.max_atoms, replace=False, p=weight + ) + atoms = atoms[index] + coordinates = coordinates[index] + + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch)