--- a +++ b/optimize.py @@ -0,0 +1,249 @@ +import argparse +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from Bio.PDB import PDBParser +from rdkit import Chem +import pandas as pd +import random +from torch_scatter import scatter_mean +from openbabel import openbabel +openbabel.obErrorLog.StopLogging() # suppress OpenBabel messages + +import utils +from lightning_modules import LigandPocketDDPM +from constants import FLOAT_TYPE, INT_TYPE +from analysis.molecule_builder import build_molecule, process_molecule +from analysis.metrics import MoleculeProperties + + +def prepare_from_sdf_files(sdf_files, atom_encoder): + + ligand_coords = [] + atom_one_hot = [] + for file in sdf_files: + rdmol = Chem.SDMolSupplier(str(file), sanitize=False)[0] + ligand_coords.append( + torch.from_numpy(rdmol.GetConformer().GetPositions()).float() + ) + types = torch.tensor([atom_encoder[a.GetSymbol()] for a in rdmol.GetAtoms()]) + atom_one_hot.append( + F.one_hot(types, num_classes=len(atom_encoder)) + ) + + return torch.cat(ligand_coords, dim=0), torch.cat(atom_one_hot, dim=0) + + +def prepare_ligands_from_mols(mols, atom_encoder, device='cpu'): + + ligand_coords = [] + atom_one_hots = [] + masks = [] + sizes = [] + for i, mol in enumerate(mols): + coord = torch.tensor(mol.GetConformer().GetPositions(), dtype=FLOAT_TYPE) + types = torch.tensor([atom_encoder[a.GetSymbol()] for a in mol.GetAtoms()], dtype=INT_TYPE) + one_hot = F.one_hot(types, num_classes=len(atom_encoder)) + mask = torch.ones(len(types), dtype=INT_TYPE) * i + ligand_coords.append(coord) + atom_one_hots.append(one_hot) + masks.append(mask) + sizes.append(len(types)) + + ligand = { + 'x': torch.cat(ligand_coords, dim=0).to(device), + 'one_hot': torch.cat(atom_one_hots, dim=0).to(device), + 'size': torch.tensor(sizes, dtype=INT_TYPE).to(device), + 'mask': torch.cat(masks, dim=0).to(device), + } + + return ligand + + +def prepare_ligand_from_pdb(biopython_atoms, atom_encoder): + + coord = torch.tensor(np.array([a.get_coord() + for a in biopython_atoms]), dtype=FLOAT_TYPE) + types = torch.tensor([atom_encoder[a.element.capitalize()] + for a in biopython_atoms]) + one_hot = F.one_hot(types, num_classes=len(atom_encoder)) + + return coord, one_hot + + +def prepare_substructure(ref_ligand, fix_atoms, pdb_model): + + if fix_atoms[0].endswith(".sdf"): + # ligand as sdf file + coord, one_hot = prepare_from_sdf_files(fix_atoms, model.lig_type_encoder) + + else: + # ligand contained in PDB; given in <chain>:<resi> format + chain, resi = ref_ligand.split(':') + ligand = utils.get_residue_with_resi(pdb_model[chain], int(resi)) + fixed_atoms = [a for a in ligand.get_atoms() if a.get_name() in set(fix_atoms)] + coord, one_hot = prepare_ligand_from_pdb(fixed_atoms, model.lig_type_encoder) + + return coord, one_hot + + +def diversify_ligands(model, pocket, mols, timesteps, + sanitize=False, + largest_frag=False, + relax_iter=0): + """ + Diversify ligands for a specified pocket. + + Parameters: + model: The model instance used for diversification. + pocket: The pocket information including coordinates and types. + mols: List of RDKit molecule objects to be diversified. + timesteps: Number of denoising steps to apply during diversification. + sanitize: If True, performs molecule sanitization post-generation (default: False). + largest_frag: If True, only the largest fragment of the generated molecule is returned (default: False). + relax_iter: Number of iterations for force field relaxation of the generated molecules (default: 0). + + Returns: + A list of diversified RDKit molecule objects. + """ + + ligand = prepare_ligands_from_mols(mols, model.lig_type_encoder, device=model.device) + + pocket_mask = pocket['mask'] + lig_mask = ligand['mask'] + + # Pocket's center of mass + pocket_com_before = scatter_mean(pocket['x'], pocket['mask'], dim=0) + + out_lig, out_pocket, _, _ = model.ddpm.diversify(ligand, pocket, noising_steps=timesteps) + + # Move generated molecule back to the original pocket position + pocket_com_after = scatter_mean(out_pocket[:, :model.x_dims], pocket_mask, dim=0) + + out_pocket[:, :model.x_dims] += \ + (pocket_com_before - pocket_com_after)[pocket_mask] + out_lig[:, :model.x_dims] += \ + (pocket_com_before - pocket_com_after)[lig_mask] + + # Build mol objects + x = out_lig[:, :model.x_dims].detach().cpu() + atom_type = out_lig[:, model.x_dims:].argmax(1).detach().cpu() + + molecules = [] + for mol_pc in zip(utils.batch_to_list(x, lig_mask), + utils.batch_to_list(atom_type, lig_mask)): + + mol = build_molecule(*mol_pc, model.dataset_info, add_coords=True) + mol = process_molecule(mol, + add_hydrogens=False, + sanitize=sanitize, + relax_iter=relax_iter, + largest_frag=largest_frag) + if mol is not None: + molecules.append(mol) + + return molecules + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--checkpoint', type=Path, default='checkpoints/crossdocked_fullatom_cond.ckpt') + parser.add_argument('--pdbfile', type=str, default='example/5ndu.pdb') + parser.add_argument('--ref_ligand', type=str, default='example/5ndu_linked_mols.sdf') + parser.add_argument('--objective', type=str, default='sa', choices={'qed', 'sa'}) + parser.add_argument('--timesteps', type=int, default=100) + parser.add_argument('--population_size', type=int, default=100) + parser.add_argument('--evolution_steps', type=int, default=10) + parser.add_argument('--top_k', type=int, default=7) + parser.add_argument('--outfile', type=Path, default='output.sdf') + parser.add_argument('--relax', action='store_true') + + + args = parser.parse_args() + + pdb_id = Path(args.pdbfile).stem + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + population_size = args.population_size + evolution_steps = args.evolution_steps + top_k = args.top_k + + # Load model + model = LigandPocketDDPM.load_from_checkpoint( + args.checkpoint, map_location=device) + model = model.to(device) + + # Prepare ligand + pocket + # Load PDB + pdb_model = PDBParser(QUIET=True).get_structure('', args.pdbfile)[0] + # Define pocket based on reference ligand + residues = utils.get_pocket_from_ligand(pdb_model, args.ref_ligand) + pocket = model.prepare_pocket(residues, repeats=population_size) + + + if args.objective == 'qed': + objective_function = MoleculeProperties().calculate_qed + elif args.objective == 'sa': + objective_function = MoleculeProperties().calculate_sa + else: + ### IMPLEMENT YOUR OWN OBJECTIVE + ### FUNCTIONS HERE + raise ValueError(f"Objective function {args.objective} not recognized.") + + ref_mol = Chem.SDMolSupplier(args.ref_ligand)[0] + + # Store molecules in history dataframe + buffer = pd.DataFrame(columns=['generation', 'score', 'fate' 'mol', 'smiles']) + + # Population initialization + buffer = buffer.append({'generation': 0, + 'score': objective_function(ref_mol), + 'fate': 'initial', 'mol': ref_mol, + 'smiles': Chem.MolToSmiles(ref_mol)}, ignore_index=True) + + for generation_idx in range(evolution_steps): + + if generation_idx == 0: + molecules = buffer['mol'].tolist() * population_size + else: + # Select top k molecules from previous generation + previous_gen = buffer[buffer['generation'] == generation_idx] + top_k_molecules = previous_gen.nlargest(top_k, 'score')['mol'].tolist() + molecules = top_k_molecules * (population_size // top_k) + + # Update the fate of selected top k molecules in the buffer + buffer.loc[buffer['generation'] == generation_idx, 'fate'] = 'survived' + + # Ensure the right number of molecules + if len(molecules) < population_size: + molecules += [random.choice(molecules) for _ in range(population_size - len(molecules))] + + + # Diversify molecules + assert len(molecules) == population_size, f"Wrong number of molecules: {len(molecules)} when it should be {population_size}" + print(f"Generation {generation_idx}, mean score: {np.mean([objective_function(mol) for mol in molecules])}") + molecules = diversify_ligands(model, + pocket, + molecules, + timesteps=args.timesteps, + sanitize=True, + relax_iter=(200 if args.relax else 0)) + + + # Evaluate and save molecules + for mol in molecules: + buffer = buffer.append({'generation': generation_idx + 1, + 'score': objective_function(mol), + 'fate': 'purged', + 'mol': mol, + 'smiles': Chem.MolToSmiles(mol)}, ignore_index=True) + + + # Make SDF files + utils.write_sdf_file(args.outfile, molecules) + # Save buffer + buffer.drop(columns=['mol']) + buffer.to_csv(args.outfile.with_suffix('.csv'))