Switch to unified view

a b/data/prepare_crossdocked.py
1
import sys
2
import torch
3
import shutil
4
from pathlib import Path
5
6
from rdkit import Chem
7
from tqdm import tqdm
8
9
10
basedir = sys.argv[1]
11
structure_dir = Path(basedir, 'crossdocked_pocket10')
12
13
test_set = torch.load(Path(basedir, 'split_by_name.pt'))['test']
14
15
receptor_dir = Path(basedir, 'receptor_pdbs')
16
receptor_dir.mkdir(exist_ok=True)
17
18
ref_ligand_dir = Path(basedir, 'reference_ligands')
19
ref_ligand_dir.mkdir(exist_ok=True)
20
21
methods = ['cvae', 'sbdd', 'p2m']
22
for method in methods:
23
    method_lig_dir = Path(basedir, f'{method}_processed')
24
    method_lig_dir.mkdir(exist_ok=True)
25
26
for pocket_idx, (receptor_name, ligand_name) in enumerate(tqdm(test_set)):
27
28
    # copy receptor file and remove underscores
29
    new_rec_name = Path(receptor_name).stem.replace('_', '-')
30
    shutil.copy(Path(structure_dir, receptor_name), Path(receptor_dir, new_rec_name + '.pdb'))
31
32
    # copy and rename reference ligands
33
    new_lig_name = new_rec_name + '_' + Path(ligand_name).stem.replace('_', '-')
34
    shutil.copy(Path(structure_dir, ligand_name), Path(ref_ligand_dir, new_lig_name + '.sdf'))
35
36
    for method in methods:
37
38
        method_pocket_dir = Path(basedir, method, f'pocket_{pocket_idx}')
39
40
        generated_mols = [Chem.SDMolSupplier(str(file), sanitize=False)[0]
41
                          for file in method_pocket_dir.glob(f'mol_*.sdf')]
42
43
        # only select first 100 molecules
44
        generated_mols = generated_mols[:100]
45
        if len(generated_mols) < 1:
46
            print('No molecule found for this pocket')
47
            continue
48
        if len(generated_mols) < 100:
49
            print('Less than 100 molecules found for this pocket')
50
51
        # write a combined sdf file
52
        sdf_path = Path(basedir, f'{method}_processed', f'{new_rec_name}_mols-pocket-{pocket_idx}.sdf')
53
        with Chem.SDWriter(str(sdf_path)) as w:
54
            for mol in generated_mols:
55
                w.write(mol)