|
a |
|
b/data/prepare_crossdocked.py |
|
|
1 |
import sys |
|
|
2 |
import torch |
|
|
3 |
import shutil |
|
|
4 |
from pathlib import Path |
|
|
5 |
|
|
|
6 |
from rdkit import Chem |
|
|
7 |
from tqdm import tqdm |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
# Dataset root directory, passed as the first command-line argument.
basedir = sys.argv[1]
structure_dir = Path(basedir) / 'crossdocked_pocket10'

# Only the 'test' entry of the split file is used by this script.
# NOTE(review): torch.load unpickles the file, so the split file must come
# from a trusted source.
test_set = torch.load(Path(basedir) / 'split_by_name.pt')['test']

# Output folders for the renamed receptor and reference-ligand copies.
receptor_dir = Path(basedir) / 'receptor_pdbs'
ref_ligand_dir = Path(basedir) / 'reference_ligands'
for out_dir in (receptor_dir, ref_ligand_dir):
    out_dir.mkdir(exist_ok=True)

# One '<method>_processed' output folder per method.
methods = ['cvae', 'sbdd', 'p2m']
for method in methods:
    method_lig_dir = Path(basedir) / f'{method}_processed'
    method_lig_dir.mkdir(exist_ok=True)
|
|
25 |
|
|
|
26 |
# For every (receptor, ligand) pair of the test split: copy the receptor and the
# reference ligand under underscore-free names, then merge each method's
# generated molecules for that pocket into one combined SDF file.
for pocket_idx, (receptor_name, ligand_name) in enumerate(tqdm(test_set)):

    # copy receptor file and remove underscores
    # ('_' is reserved as the separator in the output file names built below)
    new_rec_name = Path(receptor_name).stem.replace('_', '-')
    shutil.copy(Path(structure_dir, receptor_name), Path(receptor_dir, new_rec_name + '.pdb'))

    # copy and rename reference ligands
    new_lig_name = new_rec_name + '_' + Path(ligand_name).stem.replace('_', '-')
    shutil.copy(Path(structure_dir, ligand_name), Path(ref_ligand_dir, new_lig_name + '.sdf'))

    for method in methods:

        method_pocket_dir = Path(basedir, method, f'pocket_{pocket_idx}')

        # only select first 100 molecules
        # glob() yields paths in arbitrary, filesystem-dependent order, so sort
        # (lexicographically by path) to make the selection reproducible, and
        # truncate BEFORE parsing so that at most 100 files are read.
        mol_files = sorted(method_pocket_dir.glob('mol_*.sdf'))[:100]

        # Each file contributes its first record only. The supplier returns
        # None for a record it cannot parse; drop those, because
        # SDWriter.write() cannot handle None.
        generated_mols = [Chem.SDMolSupplier(str(file), sanitize=False)[0]
                          for file in mol_files]
        generated_mols = [mol for mol in generated_mols if mol is not None]

        if len(generated_mols) < 1:
            print('No molecule found for this pocket')
            continue
        if len(generated_mols) < 100:
            print('Less than 100 molecules found for this pocket')

        # write a combined sdf file
        sdf_path = Path(basedir, f'{method}_processed', f'{new_rec_name}_mols-pocket-{pocket_idx}.sdf')
        with Chem.SDWriter(str(sdf_path)) as w:
            for mol in generated_mols:
                w.write(mol)