# process_crossdock.py
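"""Preprocess the CrossDocked dataset for pocket-conditioned molecule generation.

Reads the protein-ligand pairs listed in `split_by_name.pt`, extracts pocket
residues within a distance cutoff of each ligand, one-hot encodes atom and
residue types, and writes one `.npz` file per data split together with
summary statistics (training SMILES, ligand/pocket size distribution, bond
length tables, type histograms).

Example invocation (the base directory is a placeholder):

    python process_crossdock.py /data/crossdocked --dist_cutoff 8.0 --ca_only
"""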
from pathlib import Path
from time import time
import argparse
import shutil
import random

from tqdm import tqdm
import numpy as np

from Bio.PDB import PDBParser
from Bio.PDB.Polypeptide import three_to_one, is_aa
from rdkit import Chem
from scipy.ndimage import gaussian_filter

import torch

from analysis.molecule_builder import build_molecule
from analysis.metrics import rdmol_to_smiles
import constants
from constants import covalent_radii, dataset_params


def process_ligand_and_pocket(pdbfile, sdffile,
                              atom_dict, dist_cutoff, ca_only):
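    """Extract ligand and pocket representations for a single complex.

    Returns a `ligand_data` dict (coordinates and one-hot atom types) and a
    `pocket_data` dict (coordinates, one-hot node types, and `chain:resid`
    identifiers of the pocket residues). Pocket residues are all standard
    amino acids with at least one atom within `dist_cutoff` of any ligand
    atom; with `ca_only=True` each residue is reduced to its alpha-carbon.
    Note: relies on the module-level `amino_acid_dict` set in __main__.
    """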
    pdb_struct = PDBParser(QUIET=True).get_structure('', pdbfile)

    try:
        ligand = Chem.SDMolSupplier(str(sdffile))[0]
        assert ligand is not None
    except Exception:
        # raise ValueError so that the caller's except clause can skip
        # this complex instead of aborting the whole run
        raise ValueError(f'cannot read sdf mol ({sdffile})')

    # remove H atoms if not in atom_dict, other atom types that aren't allowed
    # should stay so that the entire ligand can be removed from the dataset
    lig_idx = [i for i, a in enumerate(ligand.GetAtoms())
               if (a.GetSymbol().capitalize() in atom_dict
                   or a.GetSymbol() != 'H')]
    lig_atoms = [ligand.GetAtomWithIdx(i).GetSymbol() for i in lig_idx]
    lig_coords = np.array([list(ligand.GetConformer(0).GetAtomPosition(i))
                           for i in lig_idx])

    try:
        lig_one_hot = np.stack([
            np.eye(1, len(atom_dict), atom_dict[a.capitalize()]).squeeze()
            for a in lig_atoms
        ])
    except KeyError as e:
        raise KeyError(f'{e} not in atom dict ({sdffile})')

    # Find interacting pocket residues based on distance cutoff
    pocket_residues = []
    for residue in pdb_struct[0].get_residues():
        res_coords = np.array([a.get_coord() for a in residue.get_atoms()])
        if is_aa(residue.get_resname(), standard=True) and \
                np.linalg.norm(res_coords[:, None, :] - lig_coords[None, :, :],
                               axis=-1).min() < dist_cutoff:
            pocket_residues.append(residue)

    pocket_ids = [f'{res.parent.id}:{res.id[1]}' for res in pocket_residues]
    ligand_data = {
        'lig_coords': lig_coords,
        'lig_one_hot': lig_one_hot,
    }

    if ca_only:
        try:
            pocket_one_hot = []
            full_coords = []
            for res in pocket_residues:
                for atom in res.get_atoms():
                    if atom.name == 'CA':
                        pocket_one_hot.append(np.eye(
                            1, len(amino_acid_dict),
                            amino_acid_dict[three_to_one(res.get_resname())]
                        ).squeeze())
                        full_coords.append(atom.coord)
            pocket_one_hot = np.stack(pocket_one_hot)
            full_coords = np.stack(full_coords)
        except KeyError as e:
            raise KeyError(
                f'{e} not in amino acid dict ({pdbfile}, {sdffile})')
    else:
        full_atoms = np.concatenate(
            [np.array([atom.element for atom in res.get_atoms()])
             for res in pocket_residues], axis=0)
        full_coords = np.concatenate(
            [np.array([atom.coord for atom in res.get_atoms()])
             for res in pocket_residues], axis=0)

        # discard hydrogens that are not part of the vocabulary so that
        # coordinates and one-hot encodings stay aligned
        if 'H' not in amino_acid_dict:
            heavy_mask = full_atoms != 'H'
            full_atoms = full_atoms[heavy_mask]
            full_coords = full_coords[heavy_mask]

        pocket_one_hot = []
        for a in full_atoms:
            if a.capitalize() in amino_acid_dict:
                atom = np.eye(1, len(amino_acid_dict),
                              amino_acid_dict[a.capitalize()]).squeeze()
            else:
                # heavy atoms without a dedicated category are encoded as
                # all-zero vectors
                atom = np.zeros(len(amino_acid_dict))
            pocket_one_hot.append(atom)
        pocket_one_hot = np.stack(pocket_one_hot)

    pocket_data = {
        'pocket_coords': full_coords,
        'pocket_one_hot': pocket_one_hot,
        'pocket_ids': pocket_ids
    }
    return ligand_data, pocket_data


def compute_smiles(positions, one_hot, mask):
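    """Rebuild an RDKit molecule for every training sample and collect its
    SMILES string.

    `mask` maps each atom row to a sample index; molecules that fail RDKit
    sanitization are skipped.
    """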
    print("Computing SMILES ...")

    atom_types = np.argmax(one_hot, axis=-1)

    # split the flat arrays into one chunk per molecule
    sections = np.where(np.diff(mask))[0] + 1
    positions = [torch.from_numpy(x) for x in np.split(positions, sections)]
    atom_types = [torch.from_numpy(x) for x in np.split(atom_types, sections)]

    mols_smiles = []

    pbar = tqdm(enumerate(zip(positions, atom_types)),
                total=len(np.unique(mask)))
    for i, (pos, atom_type) in pbar:
        # dataset_info is the module-level dict selected in __main__
        mol = build_molecule(pos, atom_type, dataset_info)

        # BasicMolecularMetrics() computes SMILES after sanitization
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            continue

        smiles = rdmol_to_smiles(mol)
        if smiles is not None:
            mols_smiles.append(smiles)
        pbar.set_description(f'{len(mols_smiles)}/{i + 1} successful')

    return mols_smiles


def get_n_nodes(lig_mask, pocket_mask, smooth_sigma=None):
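    """Compute the joint histogram of ligand and pocket node counts.

    Entry [i, j] counts training examples with i ligand nodes and j pocket
    nodes. Optional Gaussian smoothing spreads probability mass to size
    combinations that are absent from the training set.
    """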
    # Joint distribution of ligand's and pocket's number of nodes
    idx_lig, n_nodes_lig = np.unique(lig_mask, return_counts=True)
    idx_pocket, n_nodes_pocket = np.unique(pocket_mask, return_counts=True)
    assert np.all(idx_lig == idx_pocket)

    joint_histogram = np.zeros((np.max(n_nodes_lig) + 1,
                                np.max(n_nodes_pocket) + 1))

    for nlig, npocket in zip(n_nodes_lig, n_nodes_pocket):
        joint_histogram[nlig, npocket] += 1

    print(f'Original histogram: {np.count_nonzero(joint_histogram)}/'
          f'{joint_histogram.shape[0] * joint_histogram.shape[1]} bins filled')

    # Smooth the histogram
    if smooth_sigma is not None:
        filtered_histogram = gaussian_filter(
            joint_histogram, sigma=smooth_sigma, order=0, mode='constant',
            cval=0.0, truncate=4.0)

        print(f'Smoothed histogram: {np.count_nonzero(filtered_histogram)}/'
              f'{filtered_histogram.shape[0] * filtered_histogram.shape[1]} bins filled')

        joint_histogram = filtered_histogram

    return joint_histogram


def get_bond_length_arrays(atom_mapping):
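    """Convert the bond length dictionaries `constants.bonds1/2/3` (single,
    double, triple bonds) into symmetric arrays indexed by atom type; pairs
    without a tabulated bond length are set to 0.
    """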
    bond_arrays = []
    for i in range(3):
        bond_dict = getattr(constants, f'bonds{i + 1}')
        bond_array = np.zeros((len(atom_mapping), len(atom_mapping)))
        for a1 in atom_mapping.keys():
            for a2 in atom_mapping.keys():
                if a1 in bond_dict and a2 in bond_dict[a1]:
                    bond_len = bond_dict[a1][a2]
                else:
                    bond_len = 0
                bond_array[atom_mapping[a1], atom_mapping[a2]] = bond_len

        assert np.all(bond_array == bond_array.T)
        bond_arrays.append(bond_array)

    return bond_arrays


def get_lennard_jones_rm(atom_mapping):
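    """Build the matrix of Lennard-Jones r_m parameters (the interatomic
    distance at the potential minimum) for all atom-type pairs: the shortest
    tabulated bond length if one exists, otherwise the sum of the covalent
    radii (0 for the 'others' type).
    """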
    # Bond radii for the Lennard-Jones potential
    LJ_rm = np.zeros((len(atom_mapping), len(atom_mapping)))

    for a1 in atom_mapping.keys():
        for a2 in atom_mapping.keys():
            all_bond_lengths = []
            for btype in ['bonds1', 'bonds2', 'bonds3']:
                bond_dict = getattr(constants, btype)
                if a1 in bond_dict and a2 in bond_dict[a1]:
                    all_bond_lengths.append(bond_dict[a1][a2])

            if len(all_bond_lengths) > 0:
                # take the shortest possible bond length because slightly
                # larger values aren't penalized as much
                bond_len = min(all_bond_lengths)
            else:
                if a1 == 'others' or a2 == 'others':
                    bond_len = 0
                else:
                    # Replace missing values with sum of average covalent radii
                    bond_len = covalent_radii[a1] + covalent_radii[a2]

            LJ_rm[atom_mapping[a1], atom_mapping[a2]] = bond_len

    assert np.all(LJ_rm == LJ_rm.T)
    return LJ_rm


def get_type_histograms(lig_one_hot, pocket_one_hot, atom_encoder, aa_encoder):
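    """Count occurrences of every ligand atom type and every pocket node
    type (residue types in Ca-only mode, element types in full-atom mode)
    in the one-hot encoded training arrays.
    """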
    atom_decoder = list(atom_encoder.keys())
    atom_counts = {k: 0 for k in atom_encoder.keys()}
    for a in [atom_decoder[x] for x in lig_one_hot.argmax(1)]:
        atom_counts[a] += 1

    aa_decoder = list(aa_encoder.keys())
    aa_counts = {k: 0 for k in aa_encoder.keys()}
    for r in [aa_decoder[x] for x in pocket_one_hot.argmax(1)]:
        aa_counts[r] += 1

    return atom_counts, aa_counts


def saveall(filename, pdb_and_mol_ids, lig_coords, lig_one_hot, lig_mask,
            pocket_coords, pocket_one_hot, pocket_mask):
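    """Write all processed arrays of one data split to a single .npz file
    (read back later with `np.load(..., allow_pickle=True)`).
    """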
    np.savez(filename,
             names=pdb_and_mol_ids,
             lig_coords=lig_coords,
             lig_one_hot=lig_one_hot,
             lig_mask=lig_mask,
             pocket_coords=pocket_coords,
             pocket_one_hot=pocket_one_hot,
             pocket_mask=pocket_mask
             )
    return True


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('basedir', type=Path)
    parser.add_argument('--outdir', type=Path, default=None)
    parser.add_argument('--no_H', action='store_true')  # currently unused
    parser.add_argument('--ca_only', action='store_true')
    parser.add_argument('--dist_cutoff', type=float, default=8.0)
    parser.add_argument('--random_seed', type=int, default=42)
    args = parser.parse_args()

    # make the validation split sampled below reproducible
    random.seed(args.random_seed)

    datadir = args.basedir / 'crossdocked_pocket10/'

    if args.ca_only:
        dataset_info = dataset_params['crossdock']
    else:
        dataset_info = dataset_params['crossdock_full']
    amino_acid_dict = dataset_info['aa_encoder']
    atom_dict = dataset_info['atom_encoder']
    atom_decoder = dataset_info['atom_decoder']

    # Make output directory
    if args.outdir is None:
        suffix = '_crossdock' if 'H' in atom_dict else '_crossdock_noH'
        suffix += '_ca_only_temp' if args.ca_only else '_full_temp'
        processed_dir = Path(args.basedir, f'processed{suffix}')
    else:
        processed_dir = args.outdir

    processed_dir.mkdir(exist_ok=True, parents=True)

    # Read data split
    split_path = Path(args.basedir, 'split_by_name.pt')
    data_split = torch.load(split_path)

    # There is no validation set, so copy 300 training examples (the
    # validation set is not very important in this application).
    # Note: there used to be a data leak here, but it should not matter much
    # because most metrics monitored during training are independent of the
    # pockets.
    data_split['val'] = random.sample(data_split['train'], 300)

    n_train_before = len(data_split['train'])
    n_val_before = len(data_split['val'])
    n_test_before = len(data_split['test'])

    failed_save = []

    n_samples_after = {}
    for split in data_split.keys():
        lig_coords = []
        lig_one_hot = []
        lig_mask = []
        pocket_coords = []
        pocket_one_hot = []
        pocket_mask = []
        pdb_and_mol_ids = []
        count_protein = []
        count_ligand = []
        count_total = []
        count = 0

        pdb_sdf_dir = processed_dir / split
        pdb_sdf_dir.mkdir(exist_ok=True)

        tic = time()
        num_failed = 0
        pbar = tqdm(data_split[split])
        pbar.set_description(f'#failed: {num_failed}')
        for pocket_fn, ligand_fn in pbar:

            sdffile = datadir / f'{ligand_fn}'
            pdbfile = datadir / f'{pocket_fn}'

            # make sure the receptor PDB file parses before doing any work
            try:
                struct_copy = PDBParser(QUIET=True).get_structure('', pdbfile)
            except Exception:
                num_failed += 1
                failed_save.append((pocket_fn, ligand_fn))
                print(failed_save[-1])
                pbar.set_description(f'#failed: {num_failed}')
                continue

            try:
                ligand_data, pocket_data = process_ligand_and_pocket(
                    pdbfile, sdffile,
                    atom_dict=atom_dict, dist_cutoff=args.dist_cutoff,
                    ca_only=args.ca_only)
            except (KeyError, AssertionError, FileNotFoundError, IndexError,
                    ValueError) as e:
                print(type(e).__name__, e, pocket_fn, ligand_fn)
                num_failed += 1
                pbar.set_description(f'#failed: {num_failed}')
                continue

            pdb_and_mol_ids.append(f"{pocket_fn}_{ligand_fn}")
            lig_coords.append(ligand_data['lig_coords'])
            lig_one_hot.append(ligand_data['lig_one_hot'])
            lig_mask.append(count * np.ones(len(ligand_data['lig_coords'])))
            pocket_coords.append(pocket_data['pocket_coords'])
            pocket_one_hot.append(pocket_data['pocket_one_hot'])
            pocket_mask.append(
                count * np.ones(len(pocket_data['pocket_coords'])))
            count_protein.append(pocket_data['pocket_coords'].shape[0])
            count_ligand.append(ligand_data['lig_coords'].shape[0])
            count_total.append(pocket_data['pocket_coords'].shape[0] +
                               ligand_data['lig_coords'].shape[0])
            count += 1

            if split in {'val', 'test'}:
                # Copy PDB file
                new_rec_name = Path(pdbfile).stem.replace('_', '-')
                pdb_file_out = Path(pdb_sdf_dir, f"{new_rec_name}.pdb")
                shutil.copy(pdbfile, pdb_file_out)

                # Copy SDF file
                new_lig_name = new_rec_name + '_' + Path(sdffile).stem.replace('_', '-')
                sdf_file_out = Path(pdb_sdf_dir, f'{new_lig_name}.sdf')
                shutil.copy(sdffile, sdf_file_out)

                # specify pocket residues
                with open(Path(pdb_sdf_dir, f'{new_lig_name}.txt'), 'w') as f:
                    f.write(' '.join(pocket_data['pocket_ids']))

        lig_coords = np.concatenate(lig_coords, axis=0)
        lig_one_hot = np.concatenate(lig_one_hot, axis=0)
        lig_mask = np.concatenate(lig_mask, axis=0)
        pocket_coords = np.concatenate(pocket_coords, axis=0)
        pocket_one_hot = np.concatenate(pocket_one_hot, axis=0)
        pocket_mask = np.concatenate(pocket_mask, axis=0)

        saveall(processed_dir / f'{split}.npz', pdb_and_mol_ids, lig_coords,
                lig_one_hot, lig_mask, pocket_coords,
                pocket_one_hot, pocket_mask)

        n_samples_after[split] = len(pdb_and_mol_ids)
        print(f"Processing {split} set took {(time() - tic) / 60.0:.2f} minutes")

    # --------------------------------------------------------------------------
    # Compute statistics & additional information
    # --------------------------------------------------------------------------
    with np.load(processed_dir / 'train.npz', allow_pickle=True) as data:
        lig_mask = data['lig_mask']
        pocket_mask = data['pocket_mask']
        lig_coords = data['lig_coords']
        lig_one_hot = data['lig_one_hot']
        pocket_one_hot = data['pocket_one_hot']

    # Compute SMILES for all training examples
    train_smiles = compute_smiles(lig_coords, lig_one_hot, lig_mask)
    np.save(processed_dir / 'train_smiles.npy', train_smiles)

    # Joint histogram of number of ligand and pocket nodes
    n_nodes = get_n_nodes(lig_mask, pocket_mask, smooth_sigma=1.0)
    np.save(Path(processed_dir, 'size_distribution.npy'), n_nodes)

    # Convert bond length dictionaries to arrays for batch processing
    bonds1, bonds2, bonds3 = get_bond_length_arrays(atom_dict)

    # Get bond length definitions for Lennard-Jones potential
    rm_LJ = get_lennard_jones_rm(atom_dict)

    # Get histograms of ligand and pocket node types
    atom_hist, aa_hist = get_type_histograms(lig_one_hot, pocket_one_hot,
                                             atom_dict, amino_acid_dict)

    # Create summary string
    summary_string = '# SUMMARY\n\n'
    summary_string += '# Before processing\n'
    summary_string += f'num_samples train: {n_train_before}\n'
    summary_string += f'num_samples val: {n_val_before}\n'
    summary_string += f'num_samples test: {n_test_before}\n\n'
    summary_string += '# After processing\n'
    summary_string += f"num_samples train: {n_samples_after['train']}\n"
    summary_string += f"num_samples val: {n_samples_after['val']}\n"
    summary_string += f"num_samples test: {n_samples_after['test']}\n\n"
    summary_string += '# Info\n'
    summary_string += f"'atom_encoder': {atom_dict}\n"
    summary_string += f"'atom_decoder': {list(atom_dict.keys())}\n"
    summary_string += f"'aa_encoder': {amino_acid_dict}\n"
    summary_string += f"'aa_decoder': {list(amino_acid_dict.keys())}\n"
    summary_string += f"'bonds1': {bonds1.tolist()}\n"
    summary_string += f"'bonds2': {bonds2.tolist()}\n"
    summary_string += f"'bonds3': {bonds3.tolist()}\n"
    summary_string += f"'lennard_jones_rm': {rm_LJ.tolist()}\n"
    summary_string += f"'atom_hist': {atom_hist}\n"
    summary_string += f"'aa_hist': {aa_hist}\n"
    summary_string += f"'n_nodes': {n_nodes.tolist()}\n"

    # Write summary to text file
    with open(processed_dir / 'summary.txt', 'w') as f:
        f.write(summary_string)

    # Print summary
    print(summary_string)

    # List the complexes that could not be processed
    print(failed_save)