Switch to side-by-side view

--- a
+++ b/unimol/data/pocket2mol_dataset.py
@@ -0,0 +1,403 @@
+from functools import lru_cache
+import sys
+import pickle
+import random
+import networkx as nx
+import numpy as np
+import torch
+import rdkit
+from rdkit import Chem
+from rdkit.Chem import AllChem
+sys.path.append('..')
+
+import numpy as np
+from unicore.data import BaseWrapperDataset
+
+from . import data_utils
+from unimol.utils import geom
+
+def gen_conformation(mol, num_conf=20, num_worker=8, keepHs=False):
+    try:
+        mol = Chem.AddHs(mol)
+        AllChem.EmbedMultipleConfs(mol, numConfs=num_conf, numThreads=num_worker, pruneRmsThresh=0.1, maxAttempts=5, useRandomCoords=False)
+        AllChem.MMFFOptimizeMoleculeConfs(mol, numThreads=num_worker)
+        if not keepHs:
+            mol = Chem.RemoveHs(mol)
+        return mol
+    except:
+        return None
+
+class FragmentConformationDataset(BaseWrapperDataset):
+    def __init__(
+        self,
+        dataset,
+        seed,
+        vocab,
+        conf_vocab,
+        use_pocket=True,
+        is_train=True
+    ):
+        self.dataset = dataset
+        self.seed = seed
+        self.use_pocket = use_pocket
+        self.conf_vocab = Vocabulary(vocab, conf_vocab)
+        self.is_train = is_train
+        self.set_epoch(None)
+    
+    def set_epoch(self, epoch, **unused):
+        super().set_epoch(epoch)
+        self.epoch = epoch
+        
+
+    def parse_frag_mol(self, frag_mol):
+        atom_types = [a.GetSymbol() for a in frag_mol.GetAtoms()]
+        atom_coords = np.array(frag_mol.GetConformer(0).GetPositions())
+        return {'atom_types': atom_types, 'atom_coords': atom_coords}
+    
+    def parse_frag_idx(self, vocab_conf, full_mol, atom_map):
+        if vocab_conf.GetNumConformers() == 0:
+            smi = Chem.MolToSmiles(vocab_conf)
+            vocab_conf = Chem.MolFromSmiles(smi)
+            vocab_conf = gen_conformation(vocab_conf, num_conf=1, num_worker=1, keepHs=True)
+        mol = Chem.RWMol(full_mol)
+        atom_idx = list(range(full_mol.GetNumAtoms()))
+        for i, atom in enumerate(full_mol.GetAtoms()):
+            if atom.GetAtomMapNum() not in atom_map:
+                atom_idx[i] = -1
+        for i in range(len(atom_idx) - 1, -1, -1):
+            if atom_idx[i] == -1:
+                mol.RemoveAtom(i)
+        mol = mol.GetMol()
+        #mol = Chem.RemoveHs(mol)
+        smi = Chem.MolToSmiles(mol)
+        #find the map num in smiles
+        map_num = []
+        smi_p = smi.split('[')
+        for i in range(1, len(smi_p)):
+            if ':' in smi_p[i]:
+                end_idx = smi_p[i].split(':')[1].index(']')
+                map_num.append(int(smi_p[i].split(':')[1][:end_idx]))
+
+        vocab_conf = Chem.RemoveHs(vocab_conf)
+        for i, atom in enumerate(vocab_conf.GetAtoms()):
+            if atom.GetSymbol() != 'H':
+                atom.SetAtomMapNum(map_num[i])
+        vocab_conf = Chem.AddHs(vocab_conf, addCoords=True)
+
+        if torch.isnan(torch.from_numpy(np.array(vocab_conf.GetConformer(0).GetPositions()))).any():
+            vocab_conf = gen_conformation(vocab_conf, num_conf=1, num_worker=1, keepHs=True)
+
+        return vocab_conf
+
+
+    def pocket_atom(self, atom):
+        if atom[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
+            return atom[1]
+        else:
+            return atom[0]
+    
+    def check_leaf(self, edges, index):
+        out_degree = 0
+        for edge in edges:
+            if edge[0] == index:
+                out_degree += 1
+        if out_degree == 0:
+            return True
+        else:
+            return False
+
+    def pocket_atom(self, atom):
+        if atom[0] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
+            return atom[1]
+        else:
+            return atom[0]
+    
+    def check_leaf(self, edges, index):
+        out_degree = 0
+        for edge in edges:
+            if edge[0] == index:
+                out_degree += 1
+        if out_degree == 0:
+            return True
+        else:
+            return False
+
+    @lru_cache(maxsize=16)
+    def __cached_item__(self, index: int, epoch: int):
+        
+        random.seed(self.seed + epoch)
+        #pocket
+        if self.use_pocket:
+            pocket_atoms = np.array(
+            [self.pocket_atom(item) for item in self.dataset[index]['pocket_atom']]
+        )
+        pocket_coordinates = np.stack(self.dataset[index]['pocket_coord'])
+
+        full_mol = self.dataset[index]['frags']['mol']
+
+        frag_graph = nx.Graph()
+        edges = self.dataset[index]['frags']['frag_edges']
+        frag_graph.add_edges_from(np.array(edges))
+        if len(edges) == 0:
+            frag_mol_idx = self.dataset[index]['frags']['frag_idx'][0]
+            frag_mol = self.conf_vocab.conf[frag_mol_idx]
+            f_mol_noH_num = len([atom for atom in frag_mol.GetAtoms() if atom.GetSymbol() != 'H'])
+            frag_mol = Chem.AddHs(frag_mol, addCoords=True)
+            frag_mol_data = self.parse_frag_mol(frag_mol)
+            return {
+                'atom_types': np.array([], dtype=str),
+                'atom_coords': np.zeros((1, 3), dtype=np.float32),
+                'focal_atom_local': 0,#place holder
+                'attach_atom_local': 0,#place holder
+                'focal_atom': 0,#place holder
+                'attach_atom': 0,#place holder
+                'frag_atom_types': frag_mol_data['atom_types'],
+                'frag_atom_coords': frag_mol_data['atom_coords'],
+                'end': True,
+                'torsion_output_prev': 0,#place holder
+                'coords_input_prev': frag_mol_data['atom_coords'],
+                'atom_types_withfocal': frag_mol_data['atom_types'],
+                'pocket_atoms': pocket_atoms,
+                'pocket_coordinates': pocket_coordinates,
+                'first': True,
+                'symmetric': True
+            }
+        bfs_edges = list(nx.bfs_edges(frag_graph, 0))
+        #reorder the link and local link
+        link = self.dataset[index]['frags']['links']
+        local_link = self.dataset[index]['frags']['links_local']
+        reorrder_link, reorrder_local_link = [], []
+        for i, b_edge in enumerate(bfs_edges):
+            for j, o_edge in enumerate(edges):
+                if b_edge[0] == o_edge[0] and b_edge[1] == o_edge[1]:
+                    reorrder_link.append(link[j])
+                    reorrder_local_link.append(local_link[j])
+                elif b_edge[0] == o_edge[1] and b_edge[1] == o_edge[0]:
+                    reorrder_link.append([link[j][1], link[j][0]])
+                    reorrder_local_link.append([local_link[j][1], local_link[j][0]])
+        self.dataset[index]['frags']['links'] = reorrder_link
+        self.dataset[index]['frags']['links_local'] = reorrder_local_link
+        #clip a random subgraph
+        bfs_edges_full = bfs_edges.copy()
+        clip_step = random.randint(0, len(bfs_edges))
+        #print('clip_step', clip_step)
+        start_frag = bfs_edges[0][0]
+        if clip_step != 0:
+            bfs_edges = bfs_edges[:clip_step]
+            focal_frag, attach_frag = bfs_edges[-1]
+        else:
+            focal_frag, attach_frag = bfs_edges[0]
+
+        for i, atom in enumerate(full_mol.GetAtoms()):
+            atom.SetAtomMapNum(i + 1)
+        
+        end = (clip_step == len(bfs_edges_full))
+        if clip_step == 0:
+            #frag
+            #print('first frag')
+            frag_mol_idx = self.dataset[index]['frags']['frag_idx'][start_frag]
+            frag_mol = self.conf_vocab.conf[frag_mol_idx]
+            if frag_mol is None:
+                print(frag_mol_idx, 'is None')
+            frag_mol = self.parse_frag_idx(frag_mol, full_mol, self.dataset[index]['frags']['map'][start_frag])
+            frag_mol_data = self.parse_frag_mol(frag_mol)
+            return {
+                'atom_types': np.array([], dtype=str),
+                'atom_coords': np.zeros((1, 3), dtype=np.float32),
+                'focal_atom_local': 0,#place holder
+                'attach_atom_local': 0,#place holder
+                'focal_atom': 0,#place holder
+                'attach_atom': 0,#place holder
+                'frag_atom_types': frag_mol_data['atom_types'],
+                'frag_atom_coords': frag_mol_data['atom_coords'],
+                'end': end,
+                'torsion_output_prev': 0,#place holder
+                'coords_input_prev': frag_mol_data['atom_coords'],
+                'atom_types_withfocal': frag_mol_data['atom_types'],
+                'pocket_atoms': pocket_atoms,
+                'pocket_coordinates': pocket_coordinates,
+                'first': True,
+                'symmetric': True
+            }
+        
+        
+        clip_frag_idx = [e[0] for e in bfs_edges] + [e[1] for e in bfs_edges[:-1]]
+        clip_frag_idx = np.unique(clip_frag_idx)
+        frag_attach_idx = bfs_edges[-1][1]
+        clip_map = []
+        for i in range(len(clip_frag_idx)):
+            clip_map.extend(self.dataset[index]['frags']['map'][clip_frag_idx[i]])
+        clip_map_attach = clip_map + self.dataset[index]['frags']['map'][frag_attach_idx]
+        #print('frag_node_map', self.dataset[index]['frags']['map'])
+
+        #get part mol
+        part_mol = Chem.RWMol(full_mol)
+        atom_idx = list(range(full_mol.GetNumAtoms()))
+        for i, atom in enumerate(full_mol.GetAtoms()):
+            if atom.GetSymbol() == 'H':
+                #remove if neighbor is not in clip_map
+                neighbor_atom = [atom.GetAtomMapNum() for atom in part_mol.GetAtomWithIdx(i).GetNeighbors()]
+                neighbor_map = [atom.GetAtomMapNum() for atom in part_mol.GetAtomWithIdx(i).GetNeighbors()][0]
+                if neighbor_map not in clip_map:
+                    atom_idx[i] = -1
+                if atom.GetAtomMapNum() not in clip_map:
+                    atom_idx[i] = -1
+            else:
+                if atom.GetAtomMapNum() not in clip_map:
+                    atom_idx[i] = -1
+        for i in range(len(atom_idx) - 1, -1, -1):
+            if atom_idx[i] == -1:
+                part_mol.RemoveAtom(i)
+        
+        frag_exp_link = []
+        for i in range(len(clip_frag_idx)):
+            for e_d, e in enumerate(bfs_edges_full):
+                if e[0] == clip_frag_idx[i] and e[1] not in clip_frag_idx:
+                    frag_exp_link.append(self.dataset[index]['frags']['links'][e_d][0])
+        for link_mp in frag_exp_link:
+            add_map = [i for i, atom in enumerate(part_mol.GetAtoms()) if atom.GetAtomMapNum() == link_mp][0]
+            part_mol.AddAtom(Chem.Atom(1))
+            part_mol.AddBond(add_map, part_mol.GetNumAtoms() - 1, Chem.rdchem.BondType.SINGLE)
+        part_mol = part_mol.GetMol()
+        part_mol = Chem.RemoveHs(part_mol)
+        part_mol = Chem.AddHs(part_mol, addCoords=True)
+
+
+        part_mol_atom_types = [atom.GetSymbol() for atom in part_mol.GetAtoms()]
+        part_mol_atom_coords = np.array([part_mol.GetConformer().GetAtomPosition(i) for i in range(part_mol.GetNumAtoms())])
+
+        #get part mol with attach
+        part_mol_attach = Chem.RWMol(full_mol)
+        atom_idx = list(range(full_mol.GetNumAtoms()))
+        for i, atom in enumerate(full_mol.GetAtoms()):
+            if full_mol.GetAtomWithIdx(i).GetSymbol() == 'H':
+                #remove if neighbor is not in clip_map
+                neighbor = [atom.GetAtomMapNum() for atom in part_mol_attach.GetAtomWithIdx(i).GetNeighbors()][0]
+                if neighbor not in clip_map_attach:
+                    atom_idx[i] = -1
+            else:
+                if atom.GetAtomMapNum() not in clip_map_attach:
+                    atom_idx[i] = -1
+        #print([i for i, idx in enumerate(atom_idx) if idx == -1])
+        for i in range(len(atom_idx) - 1, -1, -1):
+            if atom_idx[i] == -1:
+                part_mol_attach.RemoveAtom(i)
+        #if not self.check_leaf(edges, frag_attach_idx):
+        frag_exp_link = []
+        clip_frag_attach_idx = list(clip_frag_idx) + [frag_attach_idx]
+        for i in range(len(clip_frag_attach_idx)):
+            for e_d, e in enumerate(bfs_edges_full):
+                if e[0] == clip_frag_attach_idx[i] and e[1] not in clip_frag_attach_idx:
+                    frag_exp_link.append(self.dataset[index]['frags']['links'][e_d][0])
+        for link_mp in frag_exp_link:
+            add_map = [i for i, atom in enumerate(part_mol_attach.GetAtoms()) if atom.GetAtomMapNum() == link_mp][0]
+            part_mol_attach.AddAtom(Chem.Atom(1))
+            part_mol_attach.AddBond(add_map, part_mol_attach.GetNumAtoms() - 1, Chem.rdchem.BondType.SINGLE)
+            #print('add H atom symbol', part_mol_attach.GetAtomWithIdx(add_map).GetSymbol())
+        #else:
+        #    print('leaf node')
+        part_mol_attach = part_mol_attach.GetMol()
+        part_mol_attach = Chem.RemoveHs(part_mol_attach)
+        part_mol_attach = Chem.AddHs(part_mol_attach, addCoords=True)
+
+        part_mol_attach_atom_types = [atom.GetSymbol() for atom in part_mol_attach.GetAtoms()]
+        part_mol_attach_atom_coords = np.array([part_mol_attach.GetConformer().GetAtomPosition(i) for i in range(part_mol_attach.GetNumAtoms())])
+        '''
+        part_mol_atom = [self.dataset[index]['frags']['map'][e[0]] for e in bfs_edges] + \
+                        [self.dataset[index]['frags']['map'][e[1]] for e in bfs_edges[:-1]]
+        part_mol_atom = np.concatenate(part_mol_atom, axis=0)
+        part_mol_atom = np.unique(part_mol_atom)
+        part_mol_atom_types = self.dataset[index]['atom_types'][part_mol_atom]
+        part_mol_atom_coords = self.dataset[index]['atom_coords'][part_mol_atom]
+        '''
+
+        '''
+        #add focal atom
+        part_mol_atom_withfocal = [self.dataset[index]['frags']['map'][e[0]] for e in bfs_edges] + \
+                        [self.dataset[index]['frags']['map'][e[1]] for e in bfs_edges]
+        part_mol_atom_withfocal = np.concatenate(part_mol_atom_withfocal, axis=0)
+        part_mol_atom_withfocal = np.unique(part_mol_atom_withfocal)
+        part_mol_atom_types_withfocal = self.dataset[index]['atom_types'][part_mol_atom_withfocal]
+        part_mol_atom_coords_withfocal = self.dataset[index]['atom_coords'][part_mol_atom_withfocal]
+        '''
+        focal_atom_local = [i for i, atom in enumerate(part_mol.GetAtoms()) if atom.GetAtomMapNum() == self.dataset[index]['frags']['links'][clip_step - 1][0]][0]
+        focal_atom = [i for i, atom in enumerate(part_mol_attach.GetAtoms()) if atom.GetAtomMapNum() == self.dataset[index]['frags']['links'][clip_step - 1][0]][0]
+        #focal_atom = self.dataset[index]['frags']['map'][focal_frag][focal_atom_local]
+        
+        #frag
+        frag_mol_idx = self.dataset[index]['frags']['frag_idx'][attach_frag]
+        frag_mol = self.conf_vocab.conf[frag_mol_idx]
+
+        frag_mol = self.parse_frag_idx(frag_mol, self.dataset[index]['frags']['mol'], self.dataset[index]['frags']['map'][attach_frag])
+        frag_mol_data = self.parse_frag_mol(frag_mol)
+
+        if frag_mol is None:
+                print(frag_mol_idx, 'is None')
+
+        attach_atom_local = [i for i, atom in enumerate(frag_mol.GetAtoms()) if atom.GetAtomMapNum() == self.dataset[index]['frags']['links'][clip_step - 1][1]][0]
+        attach_atom = [i for i, atom in enumerate(part_mol_attach.GetAtoms()) if atom.GetAtomMapNum() == self.dataset[index]['frags']['links'][clip_step - 1][1]][0]
+        #attach_atom = self.dataset[index]['frags']['map'][bfs_edges[-1][1]][attach_atom_local]
+
+
+        #torsion angles
+        #prev_edge = bfs_edges[-1]
+        '''
+        prev_link_local = self.dataset[index]['frags']['link_atoms'][bfs_edges.index(prev_edge) - 1]
+        prev_link = (self.dataset[index]['frags']['map'][prev_edge[0]][prev_link_local[0]], 
+                self.dataset[index]['frags']['map'][prev_edge[1]][prev_link_local[1]])
+        index_rotate = [prev_link[1]] + self.dataset[index]['frags']['map'][focal_frag]
+        index_parent = np.concatenate([self.dataset[index]['frags']['map'][e[0]] for e in bfs_edges[:-1]], axis=0) + [prev_link[0]]
+        '''
+        index_rotate = [i for i, atom in enumerate(part_mol_attach.GetAtoms()) if atom.GetAtomMapNum() in self.dataset[index]['frags']['map'][attach_frag]]
+        index_rotate.remove(attach_atom)
+        index_rotate = [attach_atom] + index_rotate
+        index_parent = []
+        for e in bfs_edges[:-1]:
+            index_parent += self.dataset[index]['frags']['map'][e[0]]
+            index_parent += self.dataset[index]['frags']['map'][e[1]]
+        index_parent += self.dataset[index]['frags']['map'][focal_frag]
+        index_parent = list(np.unique(index_parent))
+        index_parent = [i for i, atom in enumerate(part_mol_attach.GetAtoms()) if atom.GetAtomMapNum() in index_parent]
+        index_parent.remove(focal_atom)
+        index_parent = index_parent + [focal_atom]
+        #add hydrogen
+        # get the hydrogen that connects to the rotate atoms
+        index_rotate_h = []
+        for i in index_rotate:
+            for j in part_mol_attach.GetAtomWithIdx(i).GetNeighbors():
+                if j.GetAtomicNum() == 1:
+                    index_rotate_h.append(j.GetIdx())
+        index_rotate += index_rotate_h
+        # get the hydrogen that connects to the parent atoms
+        index_parent_h = []
+        for i in index_parent:
+            for j in part_mol_attach.GetAtomWithIdx(i).GetNeighbors():
+                if j.GetAtomicNum() == 1:
+                    index_parent_h.append(j.GetIdx())
+        index_parent = index_parent_h + index_parent
+        coords_input_prev, torsion_output_prev = geom.change_torsion(part_mol_attach_atom_coords, [index_parent, index_rotate])
+        symmetric = (len(index_rotate) == 0 )
+
+        return {
+            'atom_types': part_mol_atom_types,
+            'atom_coords': part_mol_atom_coords,
+            'focal_atom_local': focal_atom_local,
+            'attach_atom_local': attach_atom_local,
+            'focal_atom': focal_atom,
+            'attach_atom': attach_atom,
+            'frag_atom_types': frag_mol_data['atom_types'],
+            'frag_atom_coords': frag_mol_data['atom_coords'],
+            'end': end,
+            'torsion_output_prev': torsion_output_prev,
+            'coords_input_prev': coords_input_prev,
+            'atom_types_withfocal': part_mol_attach_atom_types,
+            'pocket_atoms': pocket_atoms,
+            'pocket_coordinates': pocket_coordinates,
+            'first': False,
+            'symmetric': symmetric
+        }
+    
+    def __getitem__(self, index: int):
+        item = self.__cached_item__(index, self.epoch)
+        return item
+    
\ No newline at end of file