Switch to side-by-side view

--- a
+++ b/unimol/data/distance_dataset.py
@@ -0,0 +1,64 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from scipy.spatial import distance_matrix
+from functools import lru_cache
+from unicore.data import BaseWrapperDataset
+
+
+class DistanceDataset(BaseWrapperDataset):
+    def __init__(self, dataset):
+        super().__init__(dataset)
+        self.dataset = dataset
+
+    @lru_cache(maxsize=16)
+    def __getitem__(self, idx):
+        pos = self.dataset[idx].view(-1, 3).numpy()
+        dist = distance_matrix(pos, pos).astype(np.float32)
+        return torch.from_numpy(dist)
+
+
+class EdgeTypeDataset(BaseWrapperDataset):
+    def __init__(self, dataset: torch.utils.data.Dataset, num_types: int):
+        self.dataset = dataset
+        self.num_types = num_types
+
+    @lru_cache(maxsize=16)
+    def __getitem__(self, index: int):
+        node_input = self.dataset[index].clone()
+        offset = node_input.view(-1, 1) * self.num_types + node_input.view(1, -1)
+        return offset
+
+
+class CrossDistanceDataset(BaseWrapperDataset):
+    def __init__(self, mol_dataset, pocket_dataset):
+        super().__init__(mol_dataset)
+        self.dataset = mol_dataset
+        self.mol_dataset = mol_dataset
+        self.pocket_dataset = pocket_dataset
+
+    @lru_cache(maxsize=16)
+    def __getitem__(self, idx):
+        mol_pos = self.mol_dataset[idx].view(-1, 3).numpy()
+        pocket_pos = self.pocket_dataset[idx].view(-1, 3).numpy()
+        dist = distance_matrix(mol_pos, pocket_pos).astype(np.float32)
+        assert dist.shape[0] == self.mol_dataset[idx].shape[0]
+        assert dist.shape[1] == self.pocket_dataset[idx].shape[0]
+        return torch.from_numpy(dist)
+
+class CrossEdgeTypeDataset(BaseWrapperDataset):
+    def __init__(self, mol_dataset, pocket_dataset, num_types: int):
+        self.dataset = mol_dataset
+        self.mol_dataset = mol_dataset
+        self.pocket_dataset = pocket_dataset
+        self.num_types = num_types
+
+    @lru_cache(maxsize=16)
+    def __getitem__(self, index: int):
+        mol_node_input = self.mol_dataset[index].clone()
+        pocket_node_input = self.pocket_dataset[index].clone()
+        offset = mol_node_input.view(-1, 1) * self.num_types + pocket_node_input.view(1, -1)
+        return offset
\ No newline at end of file