Switch to side-by-side view

--- a
+++ b/unimol/data/cropping_dataset.py
@@ -0,0 +1,269 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+from functools import lru_cache
+import logging
+from unicore.data import BaseWrapperDataset
+from . import data_utils
+
+logger = logging.getLogger(__name__)
+
+
+class CroppingDataset(BaseWrapperDataset):
+    def __init__(self, dataset, seed, atoms, coordinates, max_atoms=256):
+        self.dataset = dataset
+        self.seed = seed
+        self.atoms = atoms
+        self.coordinates = coordinates
+        self.max_atoms = max_atoms
+        self.set_epoch(None)
+
+    def set_epoch(self, epoch, **unused):
+        super().set_epoch(epoch)
+        self.epoch = epoch
+
+    @lru_cache(maxsize=16)
+    def __cached_item__(self, index: int, epoch: int):
+        dd = self.dataset[index].copy()
+        atoms = dd[self.atoms]
+        coordinates = dd[self.coordinates]
+        if self.max_atoms and len(atoms) > self.max_atoms:
+            with data_utils.numpy_seed(self.seed, epoch, index):
+                index = np.random.choice(len(atoms), self.max_atoms, replace=False)
+                atoms = np.array(atoms)[index]
+                coordinates = coordinates[index]
+        dd[self.atoms] = atoms
+        dd[self.coordinates] = coordinates.astype(np.float32)
+        return dd
+
+    def __getitem__(self, index: int):
+        return self.__cached_item__(index, self.epoch)
+
+
+class CroppingPocketDataset(BaseWrapperDataset):
+    def __init__(self, dataset, seed, atoms, coordinates, max_atoms=256):
+        self.dataset = dataset
+        self.seed = seed
+        self.atoms = atoms
+        self.coordinates = coordinates
+        self.max_atoms = (
+            max_atoms  # max number of atoms in a molecule, None indicates no limit.
+        )
+        self.set_epoch(None)
+
+    def set_epoch(self, epoch, **unused):
+        super().set_epoch(epoch)
+        self.epoch = epoch
+
+    @lru_cache(maxsize=16)
+    def __cached_item__(self, index: int, epoch: int):
+        dd = self.dataset[index].copy()
+        atoms = dd[self.atoms]
+        coordinates = dd[self.coordinates]
+        #residue = dd["residue"]
+
+        # crop atoms according to their distance to the center of pockets
+        if self.max_atoms and len(atoms) > self.max_atoms:
+            with data_utils.numpy_seed(self.seed, epoch, index):
+                distance = np.linalg.norm(
+                    coordinates - coordinates.mean(axis=0), axis=1
+                )
+
+                def softmax(x):
+                    x -= np.max(x)
+                    x = np.exp(x) / np.sum(np.exp(x))
+                    return x
+
+                distance += 1  # prevent inf
+                weight = softmax(np.reciprocal(distance))
+                index = np.random.choice(
+                    len(atoms), self.max_atoms, replace=False, p=weight
+                )
+                atoms = atoms[index]
+                coordinates = coordinates[index]
+                #residue = residue[index]
+
+        dd[self.atoms] = atoms
+        dd[self.coordinates] = coordinates.astype(np.float32)
+        #dd["residue"] = residue
+        return dd
+
+    def __getitem__(self, index: int):
+        return self.__cached_item__(index, self.epoch)
+
+
+class CroppingResiduePocketDataset(BaseWrapperDataset):
+    def __init__(self, dataset, seed, atoms, residues, coordinates, max_atoms=256):
+        self.dataset = dataset
+        self.seed = seed
+        self.atoms = atoms
+        self.residues = residues
+        self.coordinates = coordinates
+        self.max_atoms = (
+            max_atoms  # max number of atoms in a molecule, None indicates no limit.
+        )
+
+        self.set_epoch(None)
+
+    def set_epoch(self, epoch, **unused):
+        super().set_epoch(epoch)
+        self.epoch = epoch
+
+    @lru_cache(maxsize=16)
+    def __cached_item__(self, index: int, epoch: int):
+        dd = self.dataset[index].copy()
+        atoms = dd[self.atoms]
+        residues = dd[self.residues]
+        coordinates = dd[self.coordinates]
+
+        residues_distance_map = {}
+
+        # crop atoms according to their distance to the center of pockets
+        if self.max_atoms and len(atoms) > self.max_atoms:
+            with data_utils.numpy_seed(self.seed, epoch, index):
+                distance = np.linalg.norm(
+                    coordinates - coordinates.mean(axis=0), axis=1
+                )
+                residues_ids, residues_distance = [], []
+                for res in residues:
+                    if res not in residues_ids:
+                        residues_ids.append(res)
+                        residues_distance.append(distance[residues == res].mean())
+                residues_ids = np.array(residues_ids)
+                residues_distance = np.array(residues_distance)
+
+                def softmax(x):
+                    x -= np.max(x)
+                    x = np.exp(x) / np.sum(np.exp(x))
+                    return x
+
+                residues_distance += 1  # prevent inf and smoothing out the distance
+                weight = softmax(np.reciprocal(residues_distance))
+                max_residues = self.max_atoms // (len(atoms) // (len(residues_ids) + 1))
+                if max_residues < 1:
+                    max_residues += 1
+                max_residues = min(max_residues, len(residues_ids))
+                residue_index = np.random.choice(
+                    len(residues_ids), max_residues, replace=False, p=weight
+                )
+                index = [
+                    i
+                    for i in range(len(atoms))
+                    if residues[i] in residues_ids[residue_index]
+                ]
+                atoms = atoms[index]
+                coordinates = coordinates[index]
+                residues = residues[index]
+
+        dd[self.atoms] = atoms
+        dd[self.coordinates] = coordinates.astype(np.float32)
+        dd[self.residues] = residues
+        return dd
+
+    def __getitem__(self, index: int):
+        return self.__cached_item__(index, self.epoch)
+
+
+class CroppingPocketDockingPoseDataset(BaseWrapperDataset):
+    def __init__(
+        self, dataset, seed, atoms, coordinates, holo_coordinates, max_atoms=256
+    ):
+        self.dataset = dataset
+        self.seed = seed
+        self.atoms = atoms
+        self.coordinates = coordinates 
+        self.max_atoms = max_atoms
+
+        self.set_epoch(None)
+
+    def set_epoch(self, epoch, **unused):
+        super().set_epoch(epoch)
+        self.epoch = epoch
+
+    @lru_cache(maxsize=16)
+    def __cached_item__(self, index: int, epoch: int):
+        dd = self.dataset[index].copy()
+        atoms = dd[self.atoms]
+        coordinates = dd[self.coordinates]
+        holo_coordinates = dd[self.holo_coordinates]
+
+        # crop atoms according to their distance to the center of pockets
+        #print(len(atoms))
+        if self.max_atoms and len(atoms) > self.max_atoms:
+            with data_utils.numpy_seed(self.seed, 1):
+                distance = np.linalg.norm(
+                    coordinates - coordinates.mean(axis=0), axis=1
+                )
+
+                def softmax(x):
+                    x -= np.max(x)
+                    x = np.exp(x) / np.sum(np.exp(x))
+                    return x
+
+                distance += 1  # prevent inf
+                weight = softmax(np.reciprocal(distance))
+                index = np.random.choice(
+                    len(atoms), self.max_atoms, replace=False, p=weight
+                )
+                atoms = atoms[index]
+                coordinates = coordinates[index]
+                holo_coordinates = holo_coordinates[index]
+
+        dd[self.atoms] = atoms
+        dd[self.coordinates] = coordinates.astype(np.float32)
+        dd[self.holo_coordinates] = holo_coordinates.astype(np.float32)
+        return dd
+
+    def __getitem__(self, index: int):
+        return self.__cached_item__(index, self.epoch)
+
+class CroppingPocketDockingPoseTestDataset(BaseWrapperDataset):
+    def __init__(
+        self, dataset, seed, atoms, coordinates, max_atoms=256
+    ):
+        self.dataset = dataset
+        self.seed = seed
+        self.atoms = atoms
+        self.coordinates = coordinates
+        self.max_atoms = max_atoms
+
+        self.set_epoch(None)
+
+    def set_epoch(self, epoch, **unused):
+        super().set_epoch(epoch)
+        self.epoch = epoch
+
+    @lru_cache(maxsize=16)
+    def __cached_item__(self, index: int, epoch: int):
+        dd = self.dataset[index].copy()
+        atoms = dd[self.atoms]
+        coordinates = dd[self.coordinates]
+
+        # crop atoms according to their distance to the center of pockets
+        if self.max_atoms and len(atoms) > self.max_atoms:
+            with data_utils.numpy_seed(1, 1):
+                distance = np.linalg.norm(
+                    coordinates - coordinates.mean(axis=0), axis=1
+                )
+
+                def softmax(x):
+                    x -= np.max(x)
+                    x = np.exp(x) / np.sum(np.exp(x))
+                    return x
+
+                distance += 1  # prevent inf
+                weight = softmax(np.reciprocal(distance))
+                index = np.random.choice(
+                    len(atoms), self.max_atoms, replace=False, p=weight
+                )
+                atoms = atoms[index]
+                coordinates = coordinates[index]
+
+        dd[self.atoms] = atoms
+        dd[self.coordinates] = coordinates.astype(np.float32)
+        return dd
+
+    def __getitem__(self, index: int):
+        return self.__cached_item__(index, self.epoch)