Diff of /process_bindingmoad.py [000000] .. [607087]

from pathlib import Path
from time import time
import random
from collections import defaultdict
import argparse
import warnings

from tqdm import tqdm
import numpy as np
import torch
from Bio.PDB import PDBParser
from Bio.PDB.Polypeptide import three_to_one, is_aa
from Bio.PDB import PDBIO, Select
from openbabel import openbabel
from rdkit import Chem
from rdkit.Chem import QED
from scipy.ndimage import gaussian_filter

from geometry_utils import get_bb_transform
from analysis.molecule_builder import build_molecule
from analysis.metrics import rdmol_to_smiles
import constants
from constants import covalent_radii, dataset_params
import utils

dataset_info = dataset_params['bindingmoad']
amino_acid_dict = dataset_info['aa_encoder']
atom_dict = dataset_info['atom_encoder']
atom_decoder = dataset_info['atom_decoder']


class Model0(Select):
    """Selector that keeps only the first model when writing PDB files."""
    def accept_model(self, model):
        return model.id == 0


def read_label_file(csv_path):
    """
    Read BindingMOAD's label file
    Args:
        csv_path: path to 'every.csv'
    Returns:
        Nested dictionary with all ligands. First level: EC number,
            Second level: PDB ID, Third level: list of ligands. Each ligand is
            represented as a list [ligand name, validity, SMILES string]
            (a list rather than a tuple so that a QED value can be appended
            later)
    """
    ligand_dict = {}

    with open(csv_path, 'r') as f:
        for line in f.readlines():
            row = line.split(',')

            # new protein class
            if len(row[0]) > 0:
                curr_class = row[0]
                ligand_dict[curr_class] = {}
                continue

            # new protein
            if len(row[2]) > 0:
                curr_prot = row[2]
                ligand_dict[curr_class][curr_prot] = []
                continue

            # new small molecule
            if len(row[3]) > 0:
                ligand_dict[curr_class][curr_prot].append(
                    # [ligand name, validity, SMILES string]
                    [row[3], row[4], row[9]]
                )

    return ligand_dict
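
# Illustrative shape of the dictionary returned by read_label_file (the
# values here are invented placeholders, not real entries from every.csv):
#
#     {'3.4.21.4':                    # EC number
#         {'1ABC':                    # PDB ID
#             [['LIG:A:101', 'valid', 'CC(=O)O']]}}   # one ligand per entry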


def compute_druglikeness(ligand_dict):
    """
    Computes RDKit's QED value and adds it to the dictionary
    Args:
        ligand_dict: nested ligand dictionary
    Returns:
        the same ligand dictionary with additional QED values
    """
    print("Computing QED values...")
    for p, m in tqdm([(p, m) for c in ligand_dict for p in ligand_dict[c]
                      for m in ligand_dict[c][p]]):
        mol = Chem.MolFromSmiles(m[2])
        if mol is None:
            mol_id = f'{p}_{m[0]}'
            warnings.warn(f"Could not construct molecule {mol_id} from SMILES "
                          f"string '{m[2]}'")
            continue
        m.append(QED.qed(mol))
    return ligand_dict


def filter_and_flatten(ligand_dict, qed_thresh, max_occurences, seed):
    """Flatten the nested ligand dictionary into (EC number, PDB ID, ligand)
    tuples and keep only ligands that are labeled 'valid' and whose QED value
    exceeds qed_thresh, with at most max_occurences examples per ligand name.
    """
    filtered_examples = []
    all_examples = [(c, p, m) for c in ligand_dict for p in ligand_dict[c]
                    for m in ligand_dict[c][p]]

    # shuffle to select random examples of ligands that occur more than
    # max_occurences times
    random.seed(seed)
    random.shuffle(all_examples)

    ligand_name_counter = defaultdict(int)
    print("Filtering examples...")
    for c, p, m in tqdm(all_examples):

        ligand_name, ligand_chain, ligand_resi = m[0].split(':')
        if m[1] == 'valid' and len(m) > 3 and m[3] > qed_thresh:
            if ligand_name_counter[ligand_name] < max_occurences:
                filtered_examples.append((c, p, m))
                ligand_name_counter[ligand_name] += 1

    return filtered_examples
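
# Illustrative output entry (hypothetical values): after filtering, each
# example is a tuple like
#     ('3.4.21.4', '1ABC', ['LIG:A:101', 'valid', 'CC(=O)O', 0.55])
# where the trailing float is the QED value appended by compute_druglikeness.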


def split_by_ec_number(data_list, n_val, n_test, ec_level=1):
    """
    Split dataset into training, validation and test sets based on EC numbers
    https://en.wikipedia.org/wiki/Enzyme_Commission_number
    Args:
        data_list: list of ligands
        n_val: number of validation examples
        n_test: number of test examples
        ec_level: level in the EC numbering hierarchy at which the split is
            made, i.e. items with matching EC numbers at this level are put in
            the same set
    Returns:
        dictionary with keys 'train', 'val', and 'test'
    """

    examples_per_class = defaultdict(int)
    for c, p, m in data_list:
        c_sub = '.'.join(c.split('.')[:ec_level])
        examples_per_class[c_sub] += 1

    assert sum(examples_per_class.values()) == len(data_list)

    # split ec numbers
    val_classes = set()
    for c, num in sorted(examples_per_class.items(), key=lambda x: x[1],
                         reverse=True):
        if sum([examples_per_class[x] for x in val_classes]) + num <= n_val:
            val_classes.add(c)

    test_classes = set()
    for c, num in sorted(examples_per_class.items(), key=lambda x: x[1],
                         reverse=True):
        # skip classes already used in the validation set
        if c in val_classes:
            continue
        if sum([examples_per_class[x] for x in test_classes]) + num <= n_test:
            test_classes.add(c)

    # remaining classes belong to the training set
    train_classes = {x for x in examples_per_class if
                     x not in val_classes and x not in test_classes}

    # create separate lists of examples
    data_split = {}
    data_split['train'] = [x for x in data_list if '.'.join(
        x[0].split('.')[:ec_level]) in train_classes]
    data_split['val'] = [x for x in data_list if '.'.join(
        x[0].split('.')[:ec_level]) in val_classes]
    data_split['test'] = [x for x in data_list if '.'.join(
        x[0].split('.')[:ec_level]) in test_classes]

    assert len(data_split['train']) + len(data_split['val']) + \
           len(data_split['test']) == len(data_list)

    return data_split


def ligand_list_to_dict(ligand_list):
    """Group flattened (EC number, PDB ID, ligand) examples by PDB ID."""
    out_dict = defaultdict(list)
    for _, p, m in ligand_list:
        out_dict[p].append(m)
    return out_dict


def process_ligand_and_pocket(pdb_struct, ligand_name, ligand_chain,
                              ligand_resi, dist_cutoff, ca_only,
                              compute_quaternion=False):
    # NOTE: 'pdbfile' in the error messages below is read from the enclosing
    # script scope (it is set in the processing loop under __main__)
    try:
        residues = {obj.id[1]: obj for obj in
                    pdb_struct[0][ligand_chain].get_residues()}
    except KeyError as e:
        raise KeyError(f'Chain {e} not found ({pdbfile}, '
                       f'{ligand_name}:{ligand_chain}:{ligand_resi})')
    ligand = residues[ligand_resi]
    assert ligand.get_resname() == ligand_name, \
        f"{ligand.get_resname()} != {ligand_name}"

    # remove H atoms if not in atom_dict, other atom types that aren't allowed
    # should stay so that the entire ligand can be removed from the dataset
    lig_atoms = [a for a in ligand.get_atoms()
                 if (a.element.capitalize() in atom_dict or a.element != 'H')]
    lig_coords = np.array([a.get_coord() for a in lig_atoms])

    # np.eye(1, n, k).squeeze() yields a length-n one-hot vector with the one
    # at index k
    try:
        lig_one_hot = np.stack([
            np.eye(1, len(atom_dict), atom_dict[a.element.capitalize()]).squeeze()
            for a in lig_atoms
        ])
    except KeyError as e:
        raise KeyError(
            f'Ligand atom {e} not in atom dict ({pdbfile}, '
            f'{ligand_name}:{ligand_chain}:{ligand_resi})')

    # Find interacting pocket residues based on distance cutoff
    pocket_residues = []
    for residue in pdb_struct[0].get_residues():
        res_coords = np.array([a.get_coord() for a in residue.get_atoms()])
        if is_aa(residue.get_resname(), standard=True) and \
                (((res_coords[:, None, :] - lig_coords[None, :, :]) ** 2).sum(-1) ** 0.5).min() < dist_cutoff:
            pocket_residues.append(residue)
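
    # The broadcasting above forms an (n_residue_atoms, n_ligand_atoms) matrix
    # of pairwise Euclidean distances; a residue joins the pocket if its
    # closest atom lies within dist_cutoff of any ligand atom (PDB coordinates
    # are in Angstrom).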

    # Compute transform of the canonical reference frame
    n_xyz = np.array([res['N'].get_coord() for res in pocket_residues])
    ca_xyz = np.array([res['CA'].get_coord() for res in pocket_residues])
    c_xyz = np.array([res['C'].get_coord() for res in pocket_residues])

    if compute_quaternion:
        quaternion, c_alpha = get_bb_transform(n_xyz, ca_xyz, c_xyz)
        if np.any(np.isnan(quaternion)):
            raise ValueError(
                f'Invalid value in quaternion ({pdbfile}, '
                f'{ligand_name}:{ligand_chain}:{ligand_resi})')
    else:
        c_alpha = ca_xyz

    if ca_only:
        pocket_coords = c_alpha
        try:
            pocket_one_hot = np.stack([
                np.eye(1, len(amino_acid_dict),
                       amino_acid_dict[three_to_one(res.get_resname())]).squeeze()
                for res in pocket_residues])
        except KeyError as e:
            raise KeyError(
                f'{e} not in amino acid dict ({pdbfile}, '
                f'{ligand_name}:{ligand_chain}:{ligand_resi})')
    else:
        pocket_atoms = [a for res in pocket_residues for a in res.get_atoms()
                        if (a.element.capitalize() in atom_dict or a.element != 'H')]
        pocket_coords = np.array([a.get_coord() for a in pocket_atoms])
        try:
            pocket_one_hot = np.stack([
                np.eye(1, len(atom_dict), atom_dict[a.element.capitalize()]).squeeze()
                for a in pocket_atoms
            ])
        except KeyError as e:
            raise KeyError(
                f'Pocket atom {e} not in atom dict ({pdbfile}, '
                f'{ligand_name}:{ligand_chain}:{ligand_resi})')

    pocket_ids = [f'{res.parent.id}:{res.id[1]}' for res in pocket_residues]

    ligand_data = {
        'lig_coords': lig_coords,
        'lig_one_hot': lig_one_hot,
    }
    pocket_data = {
        'pocket_coords': pocket_coords,
        'pocket_one_hot': pocket_one_hot,
        'pocket_ids': pocket_ids,
    }
    if compute_quaternion:
        pocket_data['pocket_quaternion'] = quaternion
    return ligand_data, pocket_data


def compute_smiles(positions, one_hot, mask):
    print("Computing SMILES ...")

    atom_types = np.argmax(one_hot, axis=-1)

    # mask holds one integer id per atom; since the ids are contiguous, the
    # boundaries between molecules are exactly where the mask value changes
    sections = np.where(np.diff(mask))[0] + 1
    positions = [torch.from_numpy(x) for x in np.split(positions, sections)]
    atom_types = [torch.from_numpy(x) for x in np.split(atom_types, sections)]

    mols_smiles = []

    pbar = tqdm(enumerate(zip(positions, atom_types)),
                total=len(np.unique(mask)))
    for i, (pos, atom_type) in pbar:
        mol = build_molecule(pos, atom_type, dataset_info)

        # BasicMolecularMetrics() computes SMILES after sanitization
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            continue

        smiles = rdmol_to_smiles(mol)
        if smiles is not None:
            mols_smiles.append(smiles)
        pbar.set_description(f'{len(mols_smiles)}/{i + 1} successful')

    return mols_smiles
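
# Illustrative example of the mask-based splitting above: for
# mask = [0, 0, 0, 1, 1, 2], np.where(np.diff(mask))[0] + 1 gives [3, 5],
# and np.split then cuts the arrays into per-molecule chunks of sizes
# 3, 2 and 1.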


def get_n_nodes(lig_mask, pocket_mask, smooth_sigma=None):
    # Joint distribution of ligand's and pocket's number of nodes
    idx_lig, n_nodes_lig = np.unique(lig_mask, return_counts=True)
    idx_pocket, n_nodes_pocket = np.unique(pocket_mask, return_counts=True)
    assert np.all(idx_lig == idx_pocket)

    joint_histogram = np.zeros((np.max(n_nodes_lig) + 1,
                                np.max(n_nodes_pocket) + 1))

    for nlig, npocket in zip(n_nodes_lig, n_nodes_pocket):
        joint_histogram[nlig, npocket] += 1

    print(f'Original histogram: {np.count_nonzero(joint_histogram)}/'
          f'{joint_histogram.shape[0] * joint_histogram.shape[1]} bins filled')

    # Smooth the histogram
    if smooth_sigma is not None:
        filtered_histogram = gaussian_filter(
            joint_histogram, sigma=smooth_sigma, order=0, mode='constant',
            cval=0.0, truncate=4.0)

        print(f'Smoothed histogram: {np.count_nonzero(filtered_histogram)}/'
              f'{filtered_histogram.shape[0] * filtered_histogram.shape[1]} bins filled')

        joint_histogram = filtered_histogram

    return joint_histogram
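
# Reading the histogram (illustrative indices): joint_histogram[12, 30] counts
# training examples with exactly 12 ligand nodes and 30 pocket nodes; the
# Gaussian smoothing spreads these counts into neighboring bins so that size
# combinations unseen in training still receive nonzero mass.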


def get_bond_length_arrays(atom_mapping):
    # constants.bonds1/bonds2/bonds3 tabulate typical single, double and
    # triple bond lengths between element pairs
    bond_arrays = []
    for i in range(3):
        bond_dict = getattr(constants, f'bonds{i + 1}')
        bond_array = np.zeros((len(atom_mapping), len(atom_mapping)))
        for a1 in atom_mapping.keys():
            for a2 in atom_mapping.keys():
                if a1 in bond_dict and a2 in bond_dict[a1]:
                    bond_len = bond_dict[a1][a2]
                else:
                    bond_len = 0
                bond_array[atom_mapping[a1], atom_mapping[a2]] = bond_len

        assert np.all(bond_array == bond_array.T)
        bond_arrays.append(bond_array)

    return bond_arrays
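
# Illustrative lookup (assuming 'C' and 'N' are keys of the atom encoder):
#     bonds1, bonds2, bonds3 = get_bond_length_arrays(atom_dict)
#     bonds1[atom_dict['C'], atom_dict['N']]   # single C-N bond length
# Entries are 0 where no bond between the two elements is tabulated.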


def get_lennard_jones_rm(atom_mapping):
    # Distances at which the Lennard-Jones potential reaches its minimum
    # (r_m), one entry per pair of atom types
    LJ_rm = np.zeros((len(atom_mapping), len(atom_mapping)))

    for a1 in atom_mapping.keys():
        for a2 in atom_mapping.keys():
            all_bond_lengths = []
            for btype in ['bonds1', 'bonds2', 'bonds3']:
                bond_dict = getattr(constants, btype)
                if a1 in bond_dict and a2 in bond_dict[a1]:
                    all_bond_lengths.append(bond_dict[a1][a2])

            if len(all_bond_lengths) > 0:
                # take the shortest possible bond length because slightly larger
                # values aren't penalized as much
                bond_len = min(all_bond_lengths)
            else:
                # Replace missing values with sum of average covalent radii
                bond_len = covalent_radii[a1] + covalent_radii[a2]

            LJ_rm[atom_mapping[a1], atom_mapping[a2]] = bond_len

    assert np.all(LJ_rm == LJ_rm.T)
    return LJ_rm
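
# Illustrative use of r_m in a Lennard-Jones potential (a sketch, not code
# from this repository): parameterized by the minimum position r_m, the
# potential can be written as
#     E(r) = eps * ((r_m / r) ** 12 - 2 * (r_m / r) ** 6)
# which attains its minimum value -eps exactly at r = r_m.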


def get_type_histograms(lig_one_hot, pocket_one_hot, atom_encoder, aa_encoder):

    atom_decoder = list(atom_encoder.keys())
    atom_counts = {k: 0 for k in atom_encoder.keys()}
    for a in [atom_decoder[x] for x in lig_one_hot.argmax(1)]:
        atom_counts[a] += 1

    aa_decoder = list(aa_encoder.keys())
    aa_counts = {k: 0 for k in aa_encoder.keys()}
    for r in [aa_decoder[x] for x in pocket_one_hot.argmax(1)]:
        aa_counts[r] += 1

    return atom_counts, aa_counts


def saveall(filename, pdb_and_mol_ids, lig_coords, lig_one_hot, lig_mask,
            pocket_coords, pocket_quaternion, pocket_one_hot, pocket_mask):
    # Convenience wrapper around np.savez; the main script below calls
    # np.savez directly instead
    np.savez(filename,
             names=pdb_and_mol_ids,
             lig_coords=lig_coords,
             lig_one_hot=lig_one_hot,
             lig_mask=lig_mask,
             pocket_coords=pocket_coords,
             pocket_quaternion=pocket_quaternion,
             pocket_one_hot=pocket_one_hot,
             pocket_mask=pocket_mask
             )
    return True
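
# Illustrative round trip: archives written this way can be read back with
# np.load, mirroring the statistics section at the bottom of this script:
#     with np.load('train.npz', allow_pickle=True) as data:
#         lig_coords = data['lig_coords']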


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('basedir', type=Path)
    parser.add_argument('--outdir', type=Path, default=None)
    parser.add_argument('--qed_thresh', type=float, default=0.3)
    parser.add_argument('--max_occurences', type=int, default=50)
    parser.add_argument('--num_val', type=int, default=300)
    parser.add_argument('--num_test', type=int, default=300)
    parser.add_argument('--dist_cutoff', type=float, default=8.0)
    parser.add_argument('--ca_only', action='store_true')
    parser.add_argument('--random_seed', type=int, default=42)
    parser.add_argument('--make_split', action='store_true')
    args = parser.parse_args()
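
    # Illustrative invocation (the path is a placeholder):
    #     python process_bindingmoad.py <basedir> --make_split --ca_only
    # where <basedir> is expected to contain 'every.csv' and the
    # 'BindingMOAD_2020/' directory of .bio structure files; without
    # --make_split, the precomputed split files data/moad_{split}.txt are used.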

    pdbdir = args.basedir / 'BindingMOAD_2020/'

    # Make output directory
    if args.outdir is None:
        suffix = '' if 'H' in atom_dict else '_noH'
        suffix += '_ca_only' if args.ca_only else '_full'
        processed_dir = Path(args.basedir, f'processed{suffix}')
    else:
        processed_dir = args.outdir

    processed_dir.mkdir(exist_ok=True, parents=True)

    if args.make_split:
        # Process the label file
        csv_path = args.basedir / 'every.csv'
        ligand_dict = read_label_file(csv_path)
        ligand_dict = compute_druglikeness(ligand_dict)
        filtered_examples = filter_and_flatten(
            ligand_dict, args.qed_thresh, args.max_occurences, args.random_seed)
        print(f'{len(filtered_examples)} examples after filtering')

        # Make data split
        data_split = split_by_ec_number(filtered_examples, args.num_val,
                                        args.num_test)

    else:
        # Use precomputed data split
        data_split = {}
        for split in ['test', 'val', 'train']:
            with open(f'data/moad_{split}.txt', 'r') as f:
                pocket_ids = f.read().split(',')
            # (ec-number, protein, molecule tuple)
            data_split[split] = [(None, x.split('_')[0][:4], (x.split('_')[1],))
                                 for x in pocket_ids]

    n_train_before = len(data_split['train'])
    n_val_before = len(data_split['val'])
    n_test_before = len(data_split['test'])

    # Read and process PDB files
    n_samples_after = {}
    for split in data_split.keys():
        lig_coords = []
        lig_one_hot = []
        lig_mask = []
        pocket_coords = []
        pocket_one_hot = []
        pocket_mask = []
        pdb_and_mol_ids = []
        receptors = []
        count = 0

        pdb_sdf_dir = processed_dir / split
        pdb_sdf_dir.mkdir(exist_ok=True)

        n_tot = len(data_split[split])
        pair_dict = ligand_list_to_dict(data_split[split])

        tic = time()
        num_failed = 0
        with tqdm(total=n_tot) as pbar:
            for p in pair_dict:

                pdb_successful = set()

                # try all available .bio files
                for pdbfile in sorted(pdbdir.glob(f"{p.lower()}.bio*")):

                    # Skip if all ligands have been processed already
                    if len(pair_dict[p]) == len(pdb_successful):
                        continue

                    pdb_struct = PDBParser(QUIET=True).get_structure('', pdbfile)
                    struct_copy = pdb_struct.copy()

                    n_bio_successful = 0
                    for m in pair_dict[p]:

                        # Skip already processed ligand
                        if m[0] in pdb_successful:
                            continue

                        ligand_name, ligand_chain, ligand_resi = m[0].split(':')
                        ligand_resi = int(ligand_resi)

                        try:
                            ligand_data, pocket_data = process_ligand_and_pocket(
                                pdb_struct, ligand_name, ligand_chain, ligand_resi,
                                dist_cutoff=args.dist_cutoff, ca_only=args.ca_only)
                        except (KeyError, AssertionError, FileNotFoundError,
                                IndexError, ValueError) as e:
                            # print(type(e).__name__, e)
                            continue

                        pdb_and_mol_ids.append(f"{p}_{m[0]}")
                        receptors.append(pdbfile.name)
                        lig_coords.append(ligand_data['lig_coords'])
                        lig_one_hot.append(ligand_data['lig_one_hot'])
                        # every atom of the current example receives the same
                        # integer id so it can be mapped back to its example
                        lig_mask.append(
                            count * np.ones(len(ligand_data['lig_coords'])))
                        pocket_coords.append(pocket_data['pocket_coords'])
                        # pocket_quaternion.append(
                        #     pocket_data['pocket_quaternion'])
                        pocket_one_hot.append(pocket_data['pocket_one_hot'])
                        pocket_mask.append(
                            count * np.ones(len(pocket_data['pocket_coords'])))
                        count += 1

                        pdb_successful.add(m[0])
                        n_bio_successful += 1

                        # Save additional files for affinity analysis
                        # if split in {'val', 'test', 'train'}:
                        if split in {'val', 'test'}:
                            # remove ligand from receptor
                            try:
                                struct_copy[0][ligand_chain].detach_child((f'H_{ligand_name}', ligand_resi, ' '))
                            except KeyError:
                                warnings.warn(f"Could not find ligand {(f'H_{ligand_name}', ligand_resi, ' ')} in {pdbfile}")
                                continue

                            # Create SDF file; Open Babel infers the ligand's
                            # connectivity from the 3D coordinates in the
                            # temporary .xyz file
                            atom_types = [atom_decoder[np.argmax(i)] for i in ligand_data['lig_one_hot']]
                            xyz_file = Path(pdb_sdf_dir, 'tmp.xyz')
                            utils.write_xyz_file(ligand_data['lig_coords'], atom_types, xyz_file)

                            obConversion = openbabel.OBConversion()
                            obConversion.SetInAndOutFormats("xyz", "sdf")
                            mol = openbabel.OBMol()
                            obConversion.ReadFile(mol, str(xyz_file))
                            xyz_file.unlink()

                            name = f"{p}-{pdbfile.suffix[1:]}_{m[0]}"
                            sdf_file = Path(pdb_sdf_dir, f'{name}.sdf')
                            obConversion.WriteFile(mol, str(sdf_file))

                            # specify pocket residues
                            with open(Path(pdb_sdf_dir, f'{name}.txt'), 'w') as f:
                                f.write(' '.join(pocket_data['pocket_ids']))

                    # if split in {'val', 'test', 'train'} and n_bio_successful > 0:
                    if split in {'val', 'test'} and n_bio_successful > 0:
                        # create receptor PDB file
                        pdb_file_out = Path(pdb_sdf_dir, f'{p}-{pdbfile.suffix[1:]}.pdb')
                        io = PDBIO()
                        io.set_structure(struct_copy)
                        io.save(str(pdb_file_out), select=Model0())

                pbar.update(len(pair_dict[p]))
                num_failed += (len(pair_dict[p]) - len(pdb_successful))
                pbar.set_description(f'#failed: {num_failed}')

        lig_coords = np.concatenate(lig_coords, axis=0)
        lig_one_hot = np.concatenate(lig_one_hot, axis=0)
        lig_mask = np.concatenate(lig_mask, axis=0)
        pocket_coords = np.concatenate(pocket_coords, axis=0)
        pocket_one_hot = np.concatenate(pocket_one_hot, axis=0)
        pocket_mask = np.concatenate(pocket_mask, axis=0)

        np.savez(processed_dir / f'{split}.npz', names=pdb_and_mol_ids,
                 receptors=receptors, lig_coords=lig_coords,
                 lig_one_hot=lig_one_hot, lig_mask=lig_mask,
                 pocket_coords=pocket_coords, pocket_one_hot=pocket_one_hot,
                 pocket_mask=pocket_mask)

        n_samples_after[split] = len(pdb_and_mol_ids)
        print(f"Processing {split} set took {(time() - tic)/60.0:.2f} minutes")

    # --------------------------------------------------------------------------
    # Compute statistics & additional information
    # --------------------------------------------------------------------------
    with np.load(processed_dir / 'train.npz', allow_pickle=True) as data:
        lig_mask = data['lig_mask']
        pocket_mask = data['pocket_mask']
        lig_coords = data['lig_coords']
        lig_one_hot = data['lig_one_hot']
        pocket_one_hot = data['pocket_one_hot']

    # Compute SMILES for all training examples
    train_smiles = compute_smiles(lig_coords, lig_one_hot, lig_mask)
    np.save(processed_dir / 'train_smiles.npy', train_smiles)

    # Joint histogram of number of ligand and pocket nodes
    n_nodes = get_n_nodes(lig_mask, pocket_mask, smooth_sigma=1.0)
    np.save(Path(processed_dir, 'size_distribution.npy'), n_nodes)

    # Convert bond length dictionaries to arrays for batch processing
    bonds1, bonds2, bonds3 = get_bond_length_arrays(atom_dict)

    # Get bond length definitions for Lennard-Jones potential
    rm_LJ = get_lennard_jones_rm(atom_dict)

    # Get histograms of ligand and pocket node types
    atom_hist, aa_hist = get_type_histograms(lig_one_hot, pocket_one_hot,
                                             atom_dict, amino_acid_dict)

    # Create summary string
    summary_string = '# SUMMARY\n\n'
    summary_string += '# Before processing\n'
    summary_string += f'num_samples train: {n_train_before}\n'
    summary_string += f'num_samples val: {n_val_before}\n'
    summary_string += f'num_samples test: {n_test_before}\n\n'
    summary_string += '# After processing\n'
    summary_string += f"num_samples train: {n_samples_after['train']}\n"
    summary_string += f"num_samples val: {n_samples_after['val']}\n"
    summary_string += f"num_samples test: {n_samples_after['test']}\n\n"
    summary_string += '# Info\n'
    summary_string += f"'atom_encoder': {atom_dict}\n"
    summary_string += f"'atom_decoder': {list(atom_dict.keys())}\n"
    summary_string += f"'aa_encoder': {amino_acid_dict}\n"
    summary_string += f"'aa_decoder': {list(amino_acid_dict.keys())}\n"
    summary_string += f"'bonds1': {bonds1.tolist()}\n"
    summary_string += f"'bonds2': {bonds2.tolist()}\n"
    summary_string += f"'bonds3': {bonds3.tolist()}\n"
    summary_string += f"'lennard_jones_rm': {rm_LJ.tolist()}\n"
    summary_string += f"'atom_hist': {atom_hist}\n"
    summary_string += f"'aa_hist': {aa_hist}\n"
    summary_string += f"'n_nodes': {n_nodes.tolist()}\n"

    # Write summary to text file
    with open(processed_dir / 'summary.txt', 'w') as f:
        f.write(summary_string)

    # Print summary
    print(summary_string)