Diff of /optimize.py [000000] .. [607087]

Switch to unified view

a b/optimize.py
1
import argparse
2
from pathlib import Path
3
4
import numpy as np
5
import torch
6
import torch.nn.functional as F
7
from Bio.PDB import PDBParser
8
from rdkit import Chem
9
import pandas as pd
10
import random
11
from torch_scatter import scatter_mean
12
from openbabel import openbabel
13
openbabel.obErrorLog.StopLogging()  # suppress OpenBabel messages
14
15
import utils
16
from lightning_modules import LigandPocketDDPM
17
from constants import FLOAT_TYPE, INT_TYPE
18
from analysis.molecule_builder import build_molecule, process_molecule
19
from analysis.metrics import MoleculeProperties
20
21
22
def prepare_from_sdf_files(sdf_files, atom_encoder):
    """
    Read the first molecule of every SDF file and stack all atoms together.

    Parameters:
        sdf_files: Iterable of SDF file paths. Only the first record of each
            file is read, and it is loaded without sanitization.
        atom_encoder: Mapping from element symbol to integer class index.

    Returns:
        A tuple (coords, one_hot): the atom positions of all files
        concatenated along dim 0 (converted to float), and the matching
        one-hot atom types with len(atom_encoder) classes.

    Raises:
        ValueError: If the first record of a file cannot be parsed.
    """
    ligand_coords = []
    atom_one_hot = []
    for file in sdf_files:
        rdmol = Chem.SDMolSupplier(str(file), sanitize=False)[0]
        if rdmol is None:
            # RDKit returns None for records it cannot parse instead of
            # raising; fail here with the offending file name rather than
            # with an AttributeError on the next line.
            raise ValueError(f"Could not read a molecule from '{file}'.")

        ligand_coords.append(
            torch.from_numpy(rdmol.GetConformer().GetPositions()).float()
        )
        types = torch.tensor([atom_encoder[a.GetSymbol()] for a in rdmol.GetAtoms()])
        atom_one_hot.append(
            F.one_hot(types, num_classes=len(atom_encoder))
        )

    return torch.cat(ligand_coords, dim=0), torch.cat(atom_one_hot, dim=0)
37
38
39
def prepare_ligands_from_mols(mols, atom_encoder, device='cpu'):
    """
    Pack several molecules into a single batched ligand dictionary.

    Parameters:
        mols: Sequence of molecule objects exposing GetConformer() and
            GetAtoms() (RDKit-style API).
        atom_encoder: Mapping from element symbol to integer class index.
        device: Device the resulting tensors are moved to (default: 'cpu').

    Returns:
        A dict with the keys
            'x': atom positions of all molecules concatenated along dim 0,
            'one_hot': one-hot atom types with len(atom_encoder) classes,
            'size': number of atoms of each molecule,
            'mask': for every atom, the index of the molecule it belongs to.
    """
    n_classes = len(atom_encoder)
    per_mol = {'x': [], 'one_hot': [], 'mask': []}
    atom_counts = []

    for mol_idx, rdmol in enumerate(mols):
        positions = torch.tensor(rdmol.GetConformer().GetPositions(), dtype=FLOAT_TYPE)
        symbols = [atom.GetSymbol() for atom in rdmol.GetAtoms()]
        type_idx = torch.tensor([atom_encoder[s] for s in symbols], dtype=INT_TYPE)
        n_atoms = len(type_idx)

        per_mol['x'].append(positions)
        per_mol['one_hot'].append(F.one_hot(type_idx, num_classes=n_classes))
        # Every atom of this molecule carries the molecule's batch index
        per_mol['mask'].append(torch.full((n_atoms,), mol_idx, dtype=INT_TYPE))
        atom_counts.append(n_atoms)

    return {
        'x': torch.cat(per_mol['x'], dim=0).to(device),
        'one_hot': torch.cat(per_mol['one_hot'], dim=0).to(device),
        'size': torch.tensor(atom_counts, dtype=INT_TYPE).to(device),
        'mask': torch.cat(per_mol['mask'], dim=0).to(device),
    }
63
64
65
def prepare_ligand_from_pdb(biopython_atoms, atom_encoder):
    """
    Convert a collection of atoms into coordinate and one-hot type tensors.

    Parameters:
        biopython_atoms: Atoms exposing get_coord() and an `element`
            attribute (Biopython-style API). The collection is iterated
            twice.
        atom_encoder: Mapping from capitalized element symbol to integer
            class index.

    Returns:
        A tuple (coord, one_hot) with the atom coordinates and the one-hot
        encoded atom types (len(atom_encoder) classes).
    """
    positions = np.array([atom.get_coord() for atom in biopython_atoms])
    coord = torch.tensor(positions, dtype=FLOAT_TYPE)

    # Element names are capitalized (e.g. 'CL' -> 'Cl') before the lookup
    type_indices = [atom_encoder[atom.element.capitalize()]
                    for atom in biopython_atoms]
    one_hot = F.one_hot(torch.tensor(type_indices), num_classes=len(atom_encoder))

    return coord, one_hot
74
75
76
def prepare_substructure(ref_ligand, fix_atoms, pdb_model, atom_encoder=None):
    """
    Build coordinate and one-hot tensors for the atoms that should be fixed.

    Parameters:
        ref_ligand: Reference ligand in <chain>:<resi> format. Only used
            when the fixed atoms are given as atom names.
        fix_atoms: Either a list of SDF files containing the fixed atoms
            (detected by the '.sdf' suffix of the first entry) or a list of
            atom names of the ligand residue inside the PDB model.
        pdb_model: PDB model that is indexed by chain ID to find the ligand.
        atom_encoder: Mapping from element symbol to integer class index.
            If None, the encoder of the module-level `model` is used, which
            is what this function always relied on before.

    Returns:
        A tuple (coord, one_hot) describing the fixed atoms.
    """
    if atom_encoder is None:
        # Backward-compatible fallback: the function used to read the global
        # `model`, which only exists when this file is run as a script.
        atom_encoder = model.lig_type_encoder

    # str() so that pathlib.Path entries are accepted as well
    if str(fix_atoms[0]).endswith(".sdf"):
        # ligand as sdf file
        return prepare_from_sdf_files(fix_atoms, atom_encoder)

    # ligand contained in PDB; given in <chain>:<resi> format
    chain, resi = ref_ligand.split(':')
    ligand = utils.get_residue_with_resi(pdb_model[chain], int(resi))
    # Build the lookup set once instead of once per atom
    fix_atom_names = set(fix_atoms)
    fixed_atoms = [a for a in ligand.get_atoms() if a.get_name() in fix_atom_names]
    return prepare_ligand_from_pdb(fixed_atoms, atom_encoder)
90
91
92
def diversify_ligands(model, pocket, mols, timesteps,
                    sanitize=False,
                    largest_frag=False,
                    relax_iter=0):
    """
    Create variations of the given ligands inside a pocket.

    Parameters:
        model: Model instance; its `ddpm.diversify` method produces the
            new ligands.
        pocket: Pocket dictionary; the keys 'x' and 'mask' are read here.
        mols: List of RDKit molecule objects to be diversified.
        timesteps: Passed to the model as the number of noising steps.
        sanitize: Forwarded to the molecule post-processing (default: False).
        largest_frag: Forwarded to the molecule post-processing
            (default: False).
        relax_iter: Forwarded to the molecule post-processing as the number
            of relaxation iterations (default: 0).

    Returns:
        A list with the processed molecules. Molecules for which the
        post-processing returns None are left out, so the list can be
        shorter than `mols`.
    """
    ligand = prepare_ligands_from_mols(mols, model.lig_type_encoder, device=model.device)
    pocket_mask, lig_mask = pocket['mask'], ligand['mask']

    # Remember where each pocket was centred before sampling
    com_before = scatter_mean(pocket['x'], pocket['mask'], dim=0)

    out_lig, out_pocket, _, _ = model.ddpm.diversify(ligand, pocket, noising_steps=timesteps)

    # Translate pocket and ligand so that the pocket centre is restored
    n_dims = model.x_dims
    com_after = scatter_mean(out_pocket[:, :n_dims], pocket_mask, dim=0)
    shift = com_before - com_after
    out_pocket[:, :n_dims] += shift[pocket_mask]
    out_lig[:, :n_dims] += shift[lig_mask]

    # Split the batch back into per-molecule coordinates and atom types
    coords = out_lig[:, :n_dims].detach().cpu()
    atom_types = out_lig[:, n_dims:].argmax(1).detach().cpu()
    per_mol_coords = utils.batch_to_list(coords, lig_mask)
    per_mol_types = utils.batch_to_list(atom_types, lig_mask)

    molecules = []
    for mol_coords, mol_types in zip(per_mol_coords, per_mol_types):
        candidate = build_molecule(mol_coords, mol_types, model.dataset_info,
                                   add_coords=True)
        candidate = process_molecule(candidate,
                                     add_hydrogens=False,
                                     sanitize=sanitize,
                                     relax_iter=relax_iter,
                                     largest_frag=largest_frag)
        if candidate is None:
            continue
        molecules.append(candidate)

    return molecules
148
149
150
if __name__ == "__main__":
    # Evolutionary optimization of a reference ligand: in every generation
    # the current population is diversified with the diffusion model, scored
    # with the chosen objective, and the top k molecules seed the next
    # generation. The last population is written to --outfile and the full
    # history (without the molecule objects) to a CSV file next to it.

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=Path, default='checkpoints/crossdocked_fullatom_cond.ckpt')
    parser.add_argument('--pdbfile', type=str, default='example/5ndu.pdb')
    parser.add_argument('--ref_ligand', type=str, default='example/5ndu_linked_mols.sdf')
    # A list (not a set) keeps the order shown in --help deterministic
    parser.add_argument('--objective', type=str, default='sa', choices=['qed', 'sa'])
    parser.add_argument('--timesteps', type=int, default=100)
    parser.add_argument('--population_size', type=int, default=100)
    parser.add_argument('--evolution_steps', type=int, default=10)
    parser.add_argument('--top_k', type=int, default=7)
    parser.add_argument('--outfile', type=Path, default='output.sdf')
    parser.add_argument('--relax', action='store_true')

    args = parser.parse_args()

    # Reject values that would otherwise fail later with a ZeroDivisionError
    # or with random.choice() on an empty population.
    if args.population_size < 1:
        parser.error("--population_size must be at least 1")
    if not 1 <= args.top_k <= args.population_size:
        parser.error("--top_k must be between 1 and --population_size")

    pdb_id = Path(args.pdbfile).stem

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    population_size = args.population_size
    evolution_steps = args.evolution_steps
    top_k = args.top_k

    # Load model
    model = LigandPocketDDPM.load_from_checkpoint(
        args.checkpoint, map_location=device)
    model = model.to(device)

    # Prepare ligand + pocket
    # Load PDB
    pdb_model = PDBParser(QUIET=True).get_structure('', args.pdbfile)[0]
    # Define pocket based on reference ligand
    residues = utils.get_pocket_from_ligand(pdb_model, args.ref_ligand)
    pocket = model.prepare_pocket(residues, repeats=population_size)

    if args.objective == 'qed':
        objective_function = MoleculeProperties().calculate_qed
    elif args.objective == 'sa':
        objective_function = MoleculeProperties().calculate_sa
    else:
        ### IMPLEMENT YOUR OWN OBJECTIVE
        ### FUNCTIONS HERE
        raise ValueError(f"Objective function {args.objective} not recognized.")

    ref_mol = Chem.SDMolSupplier(args.ref_ligand)[0]
    if ref_mol is None:
        # RDKit returns None instead of raising for unparsable records
        raise ValueError(f"Could not read the reference ligand from '{args.ref_ligand}'.")

    # Store molecules in history dataframe; the population is initialized
    # with the reference ligand as generation 0.
    buffer_columns = ['generation', 'score', 'fate', 'mol', 'smiles']
    buffer = pd.DataFrame([{'generation': 0,
                            'score': objective_function(ref_mol),
                            'fate': 'initial',
                            'mol': ref_mol,
                            'smiles': Chem.MolToSmiles(ref_mol)}],
                          columns=buffer_columns)

    # Defined up front so that there is something to write when no
    # evolution step is executed.
    molecules = [ref_mol]

    for generation_idx in range(evolution_steps):

        if generation_idx == 0:
            molecules = buffer['mol'].tolist() * population_size
        else:
            # Select top k molecules from previous generation
            previous_gen = buffer[buffer['generation'] == generation_idx]
            top_k_rows = previous_gen.nlargest(top_k, 'score')
            top_k_molecules = top_k_rows['mol'].tolist()
            if not top_k_molecules:
                # Nothing survived post-processing in the previous step, so
                # there are no parents left; keep the history collected so far.
                print(f"Generation {generation_idx} contains no valid molecules, stopping early.")
                break
            molecules = top_k_molecules * (population_size // top_k)

            # Update the fate of the selected top k molecules only; the rest
            # of the generation keeps its 'purged' label.
            buffer.loc[top_k_rows.index, 'fate'] = 'survived'

            # Ensure the right number of molecules
            if len(molecules) < population_size:
                molecules += [random.choice(molecules) for _ in range(population_size - len(molecules))]

        # Diversify molecules
        assert len(molecules) == population_size, f"Wrong number of molecules: {len(molecules)} when it should be {population_size}"
        print(f"Generation {generation_idx}, mean score: {np.mean([objective_function(mol) for mol in molecules])}")
        molecules = diversify_ligands(model,
                                      pocket,
                                      molecules,
                                      timesteps=args.timesteps,
                                      sanitize=True,
                                      relax_iter=(200 if args.relax else 0))

        # Evaluate and save molecules. Rows are collected first and added
        # with a single concat (DataFrame.append no longer exists in
        # pandas >= 2.0).
        new_rows = [{'generation': generation_idx + 1,
                     'score': objective_function(mol),
                     'fate': 'purged',
                     'mol': mol,
                     'smiles': Chem.MolToSmiles(mol)}
                    for mol in molecules]
        if new_rows:
            buffer = pd.concat(
                [buffer, pd.DataFrame(new_rows, columns=buffer_columns)],
                ignore_index=True)

    # Make SDF files
    utils.write_sdf_file(args.outfile, molecules)
    # Save buffer; drop() returns a new frame, so the result has to be used
    # for the molecule objects to be left out of the CSV.
    buffer.drop(columns=['mol']).to_csv(args.outfile.with_suffix('.csv'))