--- a +++ b/unimol/data/remove_hydrogen_dataset.py @@ -0,0 +1,144 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from functools import lru_cache +from unicore.data import BaseWrapperDataset + + +class RemoveHydrogenDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + atoms, + coordinates, + remove_hydrogen=False, + remove_polar_hydrogen=False, + ): + self.dataset = dataset + self.atoms = atoms + self.coordinates = coordinates + self.remove_hydrogen = remove_hydrogen + self.remove_polar_hydrogen = remove_polar_hydrogen + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + coordinates = dd[self.coordinates] + + if self.remove_hydrogen: + mask_hydrogen = atoms != "H" + atoms = atoms[mask_hydrogen] + #print(coordinates.shape) + coordinates = coordinates[mask_hydrogen] + if not self.remove_hydrogen and self.remove_polar_hydrogen: + end_idx = 0 + for i, atom in enumerate(atoms[::-1]): + if atom != "H": + break + else: + end_idx = i + 1 + if end_idx != 0: + atoms = atoms[:-end_idx] + coordinates = coordinates[:-end_idx] + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class RemoveHydrogenResiduePocketDataset(BaseWrapperDataset): + def __init__(self, dataset, atoms, residues, coordinates, remove_hydrogen=True): + self.dataset = dataset + self.atoms = atoms + self.residues = residues + self.coordinates = coordinates + self.remove_hydrogen = remove_hydrogen + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + residues = dd[self.residues] + coordinates = dd[self.coordinates] + if len(atoms) != len(residues): + min_len = min(len(atoms), len(residues)) + atoms = atoms[:min_len] + residues = residues[:min_len] + coordinates = coordinates[:min_len, :] + + if self.remove_hydrogen: + mask_hydrogen = atoms != "H" + atoms = atoms[mask_hydrogen] + residues = residues[mask_hydrogen] + coordinates = coordinates[mask_hydrogen] + + dd[self.atoms] = atoms + dd[self.residues] = residues + dd[self.coordinates] = coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class RemoveHydrogenPocketDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + atoms, + coordinates, + remove_hydrogen=True, + remove_polar_hydrogen=False, + ): + self.dataset = dataset + self.atoms = atoms + self.coordinates = coordinates + self.remove_hydrogen = remove_hydrogen + self.remove_polar_hydrogen = remove_polar_hydrogen + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + coordinates = dd[self.coordinates] + + if self.remove_hydrogen: + mask_hydrogen = atoms != "H" + atoms = atoms[mask_hydrogen] + coordinates = coordinates[mask_hydrogen] + if not self.remove_hydrogen and self.remove_polar_hydrogen: + end_idx = 0 + for i, atom in enumerate(atoms[::-1]): + if atom != "H": + break + else: + end_idx = i + 1 + if end_idx != 0: + atoms = atoms[:-end_idx] + coordinates = coordinates[:-end_idx] + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch)