Switch to side-by-side view

--- a
+++ b/py_scripts/write_dude_multi.py
@@ -0,0 +1,237 @@
+import argparse
+import gzip
+import multiprocessing as mp
+import os
+import pickle
+import random
+
+import lmdb
+import numpy as np
+import pandas as pd
+import rdkit
+import rdkit.Chem.AllChem as AllChem
+import torch
+import tqdm
+from biopandas.mol2 import PandasMol2
+from biopandas.pdb import PandasPdb
+from rdkit import Chem, RDLogger
+from rdkit.Chem.MolStandardize import rdMolStandardize
+
+RDLogger.DisableLog('rdApp.*')
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--mol_data_path', type=str, default='/data/protein/DUD-E/raw/all')
+parser.add_argument('--lmdb_path', type=str, default='docked_dude_fromweb2D.lmdb')
+args = parser.parse_args()
+
+
+def gen_conformation(mol, num_conf=20, num_worker=8):
+    try:
+        mol = Chem.AddHs(mol)
+        AllChem.EmbedMultipleConfs(mol, numConfs=num_conf, numThreads=num_worker, pruneRmsThresh=1, maxAttempts=10000, useRandomCoords=False)
+        try:
+            AllChem.MMFFOptimizeMoleculeConfs(mol, numThreads=num_worker)
+        except:
+            pass
+        mol = Chem.RemoveHs(mol)
+    except:
+        print("cannot gen conf", Chem.MolToSmiles(mol))
+        return None
+    if mol.GetNumConformers() == 0:
+        print("cannot gen conf", Chem.MolToSmiles(mol))
+        return None
+    return mol
+
+def convert_2Dmol_to_data(mol, num_conf=1, num_worker=5):
+    #to 3D
+    mol = gen_conformation(mol, num_conf, num_worker)
+    if mol is None:
+        return None
+    coords = [np.array(mol.GetConformer(i).GetPositions()) for i in range(mol.GetNumConformers())]
+    atom_types = [a.GetSymbol() for a in mol.GetAtoms()]
+    return {'coords': coords, 'atom_types': atom_types, 'smi': Chem.MolToSmiles(mol), 'mol': mol}
+
+def convert_3Dmol_to_data(mol):
+
+    if mol is None:
+        return None
+    coords = [np.array(mol.GetConformer(i).GetPositions()) for i in range(mol.GetNumConformers())]
+    atom_types = [a.GetSymbol() for a in mol.GetAtoms()]
+    return {'coords': coords, 'atom_types': atom_types, 'smi': Chem.MolToSmiles(mol), 'mol': mol}
+
+def read_pdb(path):
+    pdb_df = PandasPdb().read_pdb(path)
+
+    coord = pdb_df.df['ATOM'][['x_coord', 'y_coord', 'z_coord']]
+    atom_type = pdb_df.df['ATOM']['atom_name']
+    residue_name = pdb_df.df['ATOM']['chain_id'] + pdb_df.df['ATOM']['residue_number'].astype(str)
+    residue_type = pdb_df.df['ATOM']['residue_name']
+    protein = {'coord': np.array(coord), 
+               'atom_type': list(atom_type),
+               'residue_name': list(residue_name),
+               'residue_type': list(residue_type)}
+    return protein
+
+
+def read_sdf_gz_3d(path):
+    inf = gzip.open(path)
+    with Chem.ForwardSDMolSupplier(inf, removeHs=False, sanitize=False) as gzsuppl:
+        ms = [add_charges(x) for x in gzsuppl if x is not None]
+    ms = [rdMolStandardize.Uncharger().uncharge(Chem.RemoveHs(m)) for m in ms if m is not None]
+    return ms
+
+def add_charges(m):
+    m.UpdatePropertyCache(strict=False)
+    ps = Chem.DetectChemistryProblems(m)
+    if not ps:
+        Chem.SanitizeMol(m)
+        return m
+    for p in ps:
+        if p.GetType()=='AtomValenceException':
+            at = m.GetAtomWithIdx(p.GetAtomIdx())
+            if at.GetAtomicNum()==7 and at.GetFormalCharge()==0 and at.GetExplicitValence()==4:
+                at.SetFormalCharge(1)
+            if at.GetAtomicNum()==6 and at.GetExplicitValence()==5:
+                #remove a bond
+                for b in at.GetBonds():
+                    if b.GetBondType()==Chem.rdchem.BondType.DOUBLE:
+                        b.SetBondType(Chem.rdchem.BondType.SINGLE)
+                        break
+            if at.GetAtomicNum()==8 and at.GetFormalCharge()==0 and at.GetExplicitValence()==3:
+                at.SetFormalCharge(1)
+            if at.GetAtomicNum()==5 and at.GetFormalCharge()==0 and at.GetExplicitValence()==4:
+                at.SetFormalCharge(-1)
+    try:
+        Chem.SanitizeMol(m)
+    except:
+        return None
+    return m
+
+def get_different_raid(protein, ligand, raid=6):
+    protein_coord = protein['coord']
+    ligand_coord = ligand['coord']
+    protein_residue_name = protein['residue_name']
+    pocket_residue = set()
+    for i in range(len(protein_coord)):
+        for j in range(len(ligand_coord)):
+            if np.linalg.norm(protein_coord[i] - ligand_coord[j]) < raid:
+                pocket_residue.add(protein_residue_name[i])
+    return pocket_residue
+
+def read_mol2_ligand(path):
+    mol2_df = PandasMol2().read_mol2(path)
+    coord = mol2_df.df[['x', 'y', 'z']]
+    atom_type = mol2_df.df['atom_name']
+    ligand = {'coord': np.array(coord), 'atom_type': list(atom_type), 'mol': Chem.MolFromMol2File(path)}
+    return ligand
+
+def read_smi_mol(path):
+    with open(path, 'r') as f:
+        mols_lines = list(f.readlines())
+    smis = [l.split(' ')[0] for l in mols_lines]
+    mols = [Chem.MolFromSmiles(m) for m in smis]
+    return mols
+
+def parser(protein_path, mol_path, ligand_path, activity, pocket_index, raid=6):
+    protein = read_pdb(protein_path)
+    data_mols = read_smi_mol(mol_path)
+
+    ligand = read_mol2_ligand(ligand_path)
+    pocket_residue = get_different_raid(protein, ligand, raid=raid)
+    pocket_atom_idx = [i for i, r in enumerate(protein['residue_name']) if r in pocket_residue]
+    pocket_atom_type = [protein['atom_type'][i] for i in pocket_atom_idx]
+    pocket_coord = [protein['coord'][i] for i in pocket_atom_idx]
+    pocket_residue_type = [protein['residue_type'][i] for i in pocket_atom_idx]
+    pocket_name = protein_path.split('/')[-2]
+    pool = mp.Pool(32)
+    #mols = [convert_2Dmol_to_data(m) for m in data_mols if m is not None]
+    data_mols = [m for m in data_mols if m is not None]
+    mols = [m for m in tqdm.tqdm(pool.imap_unordered(convert_2Dmol_to_data, data_mols))]
+    mols = [m for m in mols if m is not None]
+    
+    return [{'atoms': m['atom_types'], 
+            'coordinates': m['coords'], 
+            'smi': m['smi'],
+            'mol': ligand,
+            'pocket_name': pocket_name,
+            'pocket_index': pocket_index,
+            'activity': activity, 
+            "pocket_atom_type": pocket_atom_type, 
+            "pocket_coord": pocket_coord} for m in mols]
+
+def mol_parser(mol_path, ligand_path, label):
+    data_mols = read_smi_mol(mol_path)
+    data_mols = [m for m in data_mols if m is not None]  
+    ligand = read_mol2_ligand(ligand_path)
+    pool = mp.Pool(32) 
+    mols = [m for m in tqdm.tqdm(pool.imap_unordered(convert_2Dmol_to_data, data_mols))]
+    mols = [m for m in mols if m is not None]
+    return [{'atoms': m['atom_types'], 
+            'coordinates': m['coords'], 
+            'smi': m['smi'],
+            'mol': m['mol'],
+            'label': label
+            } for m in mols]
+
+def pocket_parser(protein_path, ligand_path, pocket_index, raid=6):
+    protein = read_pdb(protein_path)
+    ligand = read_mol2_ligand(ligand_path)
+    pocket_residue = get_different_raid(protein, ligand, raid=raid)
+    pocket_atom_idx = [i for i, r in enumerate(protein['residue_name']) if r in pocket_residue]
+    pocket_atom_type = [protein['atom_type'][i] for i in pocket_atom_idx]
+    pocket_coord = [protein['coord'][i] for i in pocket_atom_idx]
+    pocket_residue_type = [protein['residue_type'][i] for i in pocket_atom_idx]
+    pocket_name = protein_path.split('/')[-2]
+    return {'pocket': pocket_name,
+            'pocket_index': pocket_index,
+            "pocket_atoms": pocket_atom_type, 
+            "pocket_coordinates": pocket_coord}
+
+def write_lmdb(data, lmdb_path):
+    #resume
+
+    env = lmdb.open(lmdb_path, subdir=False, readonly=False, lock=False, readahead=False, meminit=False, map_size=1099511627776)
+    num = 0
+    with env.begin(write=True) as txn:
+        for d in data:
+            txn.put(str(num).encode('ascii'), pickle.dumps(d))
+            num += 1
+
+if __name__ == '__main__':
+    protein_path = [os.path.join(args.mol_data_path, x, 'receptor.pdb') for x in os.listdir(args.mol_data_path)]
+    act_mol_path = [os.path.join(args.mol_data_path, x, 'actives_final.ism') for x in os.listdir(args.mol_data_path)]
+    decoy_mol_path = [os.path.join(args.mol_data_path, x, 'decoys_final.ism') for x in os.listdir(args.mol_data_path)]
+    
+    
+    
+    for i, pocket in tqdm.tqdm(enumerate(protein_path)):
+        # acive mols
+        print(i, pocket)
+        data = []
+        d_active = (mol_parser(act_mol_path[i], pocket.replace('receptor.pdb', 'crystal_ligand.mol2'), 1))
+        
+        data.extend(d_active)
+
+        # decoy mols
+        d_decoy = (mol_parser(decoy_mol_path[i], pocket.replace('receptor.pdb', 'crystal_ligand.mol2'), 0))
+        
+        data.extend(d_decoy)
+
+        write_lmdb(data, pocket.replace('receptor.pdb', 'mols.lmdb'))
+
+        # write pocket
+        d = pocket_parser(pocket, pocket.replace('receptor.pdb', 'crystal_ligand.mol2'), i)
+        write_lmdb([d], pocket.replace('receptor.pdb', 'pocket.lmdb'))
+
+        # number of lines in actives_final.smi 
+        with open(act_mol_path[i], 'r') as f:
+            mols_lines = list(f.readlines())
+            print("active", len(d_active), len(mols_lines))
+
+
+       
+        with open(decoy_mol_path[i], 'r') as f:
+            mols_lines = list(f.readlines())
+            print("decoy", len(d_decoy), len(mols_lines))
+        
+