--- a
+++ b/unimol/utils/docking_utils.py
@@ -0,0 +1,216 @@
+# Copyright (c) DP Technology, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+from rdkit import Chem
+from rdkit.Chem import AllChem
+from rdkit import RDLogger
+
+RDLogger.DisableLog("rdApp.*")
+import warnings
+
+warnings.filterwarnings(action="ignore")
+from rdkit.Chem import rdMolTransforms
+import copy
+import lmdb
+import pickle
+import pandas as pd
+
+
+def get_torsions(m, removeHs=True):
+    """Pick at most one dihedral (four atom indices) per rotatable bond of ``m``.
+
+    Candidate bonds are found with the SMARTS
+    ``[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]``: single, non-ring bonds whose two end
+    atoms are neither degree-1 nor part of a triple bond. For every match only
+    the first non-central bond of the first matched atom is considered, and it
+    is paired with the first acceptable bond of the second matched atom, so a
+    matched bond contributes zero or one torsion.
+
+    Args:
+        m: RDKit molecule to analyse.
+        removeHs: When true, hydrogens are stripped with ``Chem.RemoveHs``
+            before matching, so the returned indices refer to the
+            hydrogen-free molecule.
+
+    Returns:
+        A list of 4-tuples of atom indices. When the outer atom on the second
+        matched atom's side is in a ring the tuple is emitted in reversed
+        order (presumably so the ring side stays fixed when the dihedral is
+        later set -- confirm against the caller).
+    """
+    if removeHs:
+        m = Chem.RemoveHs(m)
+    rotatable_query = Chem.MolFromSmarts("[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]")
+    torsions = []
+    for begin_idx, end_idx in m.GetSubstructMatches(rotatable_query):
+        central_idx = m.GetBondBetweenAtoms(begin_idx, end_idx).GetIdx()
+        # Only the first non-central bond on the begin atom is ever tried.
+        anchor_bond = next(
+            (
+                candidate
+                for candidate in m.GetAtomWithIdx(begin_idx).GetBonds()
+                if candidate.GetIdx() != central_idx
+            ),
+            None,
+        )
+        if anchor_bond is None:
+            continue
+        anchor_bond_idx = anchor_bond.GetIdx()
+        left_idx = anchor_bond.GetOtherAtomIdx(begin_idx)
+        left_is_hydrogen = m.GetAtomWithIdx(left_idx).GetAtomicNum() == 1
+        for end_bond in m.GetAtomWithIdx(end_idx).GetBonds():
+            if end_bond.GetIdx() in (central_idx, anchor_bond_idx):
+                continue
+            right_idx = end_bond.GetOtherAtomIdx(end_idx)
+            # skip 3-membered rings
+            if right_idx == left_idx:
+                continue
+            right_atom = m.GetAtomWithIdx(right_idx)
+            # skip torsions that include hydrogens
+            if left_is_hydrogen or right_atom.GetAtomicNum() == 1:
+                continue
+            if right_atom.IsInRing():
+                torsions.append((right_idx, end_idx, begin_idx, left_idx))
+            else:
+                torsions.append((left_idx, begin_idx, end_idx, right_idx))
+            # The first acceptable partner wins; stop scanning this bond.
+            break
+    return torsions
+
+
+def SetDihedral(conf, atom_idx, new_vale):
+    """Set one dihedral of ``conf`` via ``rdMolTransforms.SetDihedralRad``.
+
+    Args:
+        conf: RDKit conformer that is modified in place.
+        atom_idx: Indexable holding the four atom indices of the dihedral;
+            only positions 0-3 are read.
+        new_vale: Target dihedral value passed to ``SetDihedralRad``
+            (radians).
+    """
+    first, second, third, fourth = (
+        atom_idx[0],
+        atom_idx[1],
+        atom_idx[2],
+        atom_idx[3],
+    )
+    rdMolTransforms.SetDihedralRad(conf, first, second, third, fourth, new_vale)
+
+
+def single_conf_gen_bonds(tgt_mol, num_confs=1000, seed=42, removeHs=True):
+    """Embed conformers for a molecule and randomise their rotatable torsions.
+
+    The input molecule is copied, hydrogens are added, and up to ``num_confs``
+    conformers are embedded with ``AllChem.EmbedMultipleConfs``. Every
+    torsion returned by ``get_torsions`` is then set to a random angle and
+    each conformer is canonicalised with
+    ``rdMolTransforms.CanonicalizeConformer``.
+
+    Args:
+        tgt_mol: RDKit molecule; it is deep-copied and not modified.
+        num_confs: Number of conformers requested from the embedding.
+        seed: Random seed for the embedding only. The torsion angles of
+            conformer ``i`` are drawn from a generator seeded with ``i``.
+        removeHs: When true, hydrogens are removed again after embedding and
+            torsions are computed on the hydrogen-free molecule.
+
+    Returns:
+        The copied molecule carrying the generated conformers. It has no
+        conformers when the embedding returned no conformer ids.
+    """
+    mol = copy.deepcopy(tgt_mol)
+    mol = Chem.AddHs(mol)
+    conformer_ids = AllChem.EmbedMultipleConfs(
+        mol, numConfs=num_confs, randomSeed=seed, clearConfs=True
+    )
+    if removeHs:
+        mol = Chem.RemoveHs(mol)
+    rotable_bonds = get_torsions(mol, removeHs=removeHs)
+    # Fetch the conformers once instead of on every access inside the loops.
+    conformers = mol.GetConformers()
+    for i in range(len(conformer_ids)):
+        # A local RandomState seeded with i yields the same numbers as
+        # np.random.seed(i) + np.random.rand(...), without resetting the
+        # process-wide NumPy RNG as a side effect.
+        rng = np.random.RandomState(i)
+        # The truncated pi constant is kept so generated angles stay
+        # bit-identical to earlier runs.
+        values = 3.1415926 * 2 * rng.rand(len(rotable_bonds))
+        conf = conformers[i]
+        for torsion, value in zip(rotable_bonds, values):
+            SetDihedral(conf, torsion, value)
+        rdMolTransforms.CanonicalizeConformer(conf)
+    return mol
+
+
+def load_lmdb_data(lmdb_path, key):
+    """Read one field from every pickled record of an LMDB file.
+
+    Records are looked up under the ASCII keys ``"0"``, ``"1"``, ... up to the
+    number of entries in the database, so the file is assumed to be keyed by
+    consecutive integers; a missing key makes ``pickle.loads`` fail on
+    ``None``.
+
+    Args:
+        lmdb_path: Path of the LMDB file (opened with ``subdir=False``,
+            read-only and without locking).
+        key: Key to extract from each unpickled record.
+
+    Returns:
+        A list with ``record[key]`` for every record, in index order.
+    """
+    env = lmdb.open(
+        lmdb_path,
+        subdir=False,
+        readonly=True,
+        lock=False,
+        readahead=False,
+        meminit=False,
+        max_readers=256,
+    )
+    try:
+        # The entry count comes from the environment statistics, so the keys
+        # do not have to be materialised just to be counted.
+        num_records = env.stat()["entries"]
+        collects = []
+        with env.begin() as txn:
+            for idx in range(num_records):
+                datapoint_pickled = txn.get(f"{idx}".encode("ascii"))
+                # NOTE(security): pickle.loads can execute arbitrary code;
+                # only load LMDB files from trusted sources.
+                data = pickle.loads(datapoint_pickled)
+                collects.append(data[key])
+    finally:
+        # Release the environment handle even when a record fails to load.
+        env.close()
+    return collects
+
+
+def docking_data_pre(raw_data_path, predict_path):
+    """Collect molecules and per-sample model predictions for docking.
+
+    Molecules come from the ``"mol_list"`` field of the LMDB file at
+    ``raw_data_path`` (flattened, hydrogens removed). Predictions come from
+    the pickle at ``predict_path``, which is iterated as a sequence of batch
+    dictionaries. For each sample, positions whose token id is greater than 2
+    are kept (presumably this drops special/padding tokens -- confirm against
+    the dictionary used by the model).
+
+    Args:
+        raw_data_path: Path handed to ``load_lmdb_data``.
+        predict_path: Path handed to ``pandas.read_pickle``.
+
+    Returns:
+        An 8-tuple of lists: molecules, SMILES names, pocket names, pocket
+        coordinates, cross distance predictions, holo distance predictions,
+        holo coordinates and holo centre coordinates. The four array-valued
+        lists hold ``float32`` NumPy arrays; the holo centre entries are kept
+        as sliced from the batch (first three elements, no NumPy conversion).
+    """
+    nested_mols = load_lmdb_data(raw_data_path, "mol_list")
+    mol_list = [Chem.RemoveHs(mol) for group in nested_mols for mol in group]
+    predict = pd.read_pickle(predict_path)
+
+    smi_list = []
+    pocket_list = []
+    pocket_coords_list = []
+    distance_predict_list = []
+    holo_distance_predict_list = []
+    holo_coords_list = []
+    holo_center_coords_list = []
+
+    for batch in predict:
+        for i in range(batch["atoms"].size(0)):
+            smi_list.append(batch["smi_name"][i])
+            pocket_list.append(batch["pocket_name"][i])
+
+            cross_distance = batch["cross_distance_predict"][i]
+            ligand_mask = batch["atoms"][i] > 2
+            pocket_mask = batch["pocket_atoms"][i] > 2
+            # Rows are ligand positions, columns are pocket positions.
+            cross_distance = cross_distance[ligand_mask][:, pocket_mask]
+            pocket_xyz = batch["pocket_coordinates"][i][pocket_mask, :]
+
+            ligand_distance = batch["holo_distance_predict"][i]
+            ligand_distance = ligand_distance[ligand_mask][:, ligand_mask]
+
+            holo_xyz = batch["holo_coordinates"][i][ligand_mask, :]
+            holo_center = batch["holo_center_coordinates"][i][:3]
+
+            pocket_coords_list.append(pocket_xyz.numpy().astype(np.float32))
+            distance_predict_list.append(
+                cross_distance.numpy().astype(np.float32)
+            )
+            holo_distance_predict_list.append(
+                ligand_distance.numpy().astype(np.float32)
+            )
+            holo_coords_list.append(holo_xyz.numpy().astype(np.float32))
+            holo_center_coords_list.append(holo_center)
+
+    return (
+        mol_list,
+        smi_list,
+        pocket_list,
+        pocket_coords_list,
+        distance_predict_list,
+        holo_distance_predict_list,
+        holo_coords_list,
+        holo_center_coords_list,
+    )
+
+
+def ensemble_iterations(
+    mol_list,
+    smi_list,
+    pocket_list,
+    pocket_coords_list,
+    distance_predict_list,
+    holo_distance_predict_list,
+    holo_coords_list,
+    holo_center_coords_list,
+    tta_times=10,
+):
+    """Yield one docking work item per group of ``tta_times`` samples.
+
+    The inputs are walked in consecutive windows of ``tta_times`` entries;
+    entries beyond the last full window are not yielded. For each window the
+    first molecule is copied and ``single_conf_gen_bonds`` generates
+    ``tta_times`` starting conformers for it.
+
+    Yields:
+        A list ``[initial_coords_list, mol, smi, pocket, pocket_coords,
+        distance_predict_tta, holo_distance_predict_tta, holo_coords,
+        holo_center_coords]``. The two ``*_tta`` items are the window's
+        slices of the distance predictions; the other per-sample items are
+        taken from the first entry of the window. ``initial_coords_list``
+        holds one ``float32`` position array per generated conformer.
+    """
+    num_groups = len(mol_list) // tta_times
+    for group in range(num_groups):
+        first = group * tta_times
+        window = slice(first, first + tta_times)
+        distance_predict_tta = distance_predict_list[window]
+        holo_distance_predict_tta = holo_distance_predict_list[window]
+
+        mol = copy.deepcopy(mol_list[first])
+        rdkit_mol = single_conf_gen_bonds(
+            mol, num_confs=tta_times, seed=42, removeHs=True
+        )
+        initial_coords_list = [
+            conformer.GetPositions().astype(np.float32)
+            for conformer in rdkit_mol.GetConformers()
+        ]
+
+        yield [
+            initial_coords_list,
+            mol,
+            smi_list[first],
+            pocket_list[first],
+            pocket_coords_list[first],
+            distance_predict_tta,
+            holo_distance_predict_tta,
+            holo_coords_list[first],
+            holo_center_coords_list[first],
+        ]
+
+
+def rmsd_func(holo_coords, predict_coords):
+    """Return the RMSD between reference and predicted coordinates.
+
+    Args:
+        holo_coords: Reference coordinate array; the squared error is divided
+            by the length of its first axis.
+        predict_coords: Predicted coordinates with a shape compatible with
+            ``holo_coords``, or ``None`` / a scalar NaN to mark a missing
+            prediction.
+
+    Returns:
+        The root-mean-square deviation, or the sentinel ``1000.0`` when the
+        prediction is missing.
+    """
+    # NaN must be detected by value: an identity test against np.nan misses
+    # float("nan") and np.float64("nan"), which would yield a NaN RMSD.
+    prediction_missing = predict_coords is None or (
+        np.ndim(predict_coords) == 0 and np.isnan(predict_coords)
+    )
+    if prediction_missing:
+        return 1000.0
+    num_atoms = holo_coords.shape[0]
+    return np.sqrt(np.sum((predict_coords - holo_coords) ** 2) / num_atoms)
+
+
+def print_results(rmsd_results):
+    """Print RMSD success rates at fixed cutoffs and the average RMSD.
+
+    Args:
+        rmsd_results: RMSD values as a NumPy array or any sequence that
+            ``np.asarray`` can convert (e.g. a plain list).
+
+    For each cutoff (1.0, 1.5, 2.0, 3.0, 5.0) the fraction of values strictly
+    below it is printed, followed by the mean of all values.
+    """
+    # Coerce first so the element-wise comparisons also work for plain lists.
+    rmsd_values = np.asarray(rmsd_results)
+    cutoffs = (
+        ("RMSD < 1.0 : ", 1.0),
+        ("RMSD < 1.5 : ", 1.5),
+        ("RMSD < 2.0 : ", 2.0),
+        ("RMSD < 3.0 : ", 3.0),
+        ("RMSD < 5.0 : ", 5.0),
+    )
+    for label, cutoff in cutoffs:
+        print(label, np.mean(rmsd_values < cutoff))
+    print("avg RMSD : ", np.mean(rmsd_values))