<a href="https://colab.research.google.com/github/arneschneuing/DiffSBDD/blob/main/colab/DiffSBDD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DiffSBDD: Structure-based Drug Design with Equivariant Diffusion Models

[**[Paper]**](https://arxiv.org/abs/2210.13695)
[**[Code]**](https://github.com/arneschneuing/DiffSBDD)

Make sure to select `Runtime` -> `Change runtime type` -> `GPU` before you run the script.

<img src="https://raw.githubusercontent.com/arneschneuing/DiffSBDD/main/img/overview.png" height=250>

In [None]:
#@title  Install condacolab (the kernel will be restarted, after that you can execute the remaining cells)
# Bootstrap step: install the `condacolab` helper package, then run its installer.
# As the cell title states, the kernel is restarted afterwards, so run this cell on
# its own and continue with the following cells once the runtime is back.
!pip install -q condacolab
import condacolab
condacolab.install()

In [None]:
#@title Install dependencies (this will take about 5-10 minutes)
%cd /content

import os

commands = [
    "pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu118",
    "pip install pytorch-lightning==1.8.4",
    "pip install wandb==0.13.1",
    "pip install rdkit==2022.3.3",
    "pip install biopython==1.79",
    "pip install imageio==2.21.2",
    "pip install scipy==1.7.3",
    "pip install pyg-lib torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu118.html",
    "pip install networkx==2.8.6",
    "pip install py3Dmol==1.8.1",
    "conda install openbabel -c conda-forge",
    "git clone https://github.com/arneschneuing/DiffSBDD.git",
    "mkdir -p /content/DiffSBDD/checkpoints",
    "wget -P /content/DiffSBDD/checkpoints https://zenodo.org/record/8183747/files/moad_fullatom_cond.ckpt",
    "wget -P /content/DiffSBDD/checkpoints https://zenodo.org/record/8183747/files/moad_fullatom_joint.ckpt",
]

errors = {}

if not os.path.isfile("/content/READY"):
  for cmd in commands:
    # os.system(cmd)
    with os.popen(cmd) as f:
      out = f.read()
      status = f.close()

    if status is not None:
      errors[cmd] = out
      print(f"\n\nAn error occurred while running '{cmd}'\n")
      print("Status:\t", status)
      print("Message:\t", out)

if len(errors) == 0:
  os.system("touch /content/READY")

## Choose target PDB

In [2]:
from google.colab import files
from google.colab import output
output.enable_custom_widget_manager()
import os.path
from pathlib import Path
import urllib
import urllib.request  # `import urllib` alone does not guarantee the `request` submodule is loaded
import os

# Directories for the input structure and the generated SDF files.
input_dir = Path("/content/input_pdbs/")
output_dir = Path("/content/output_sdfs/")
input_dir.mkdir(parents=True, exist_ok=True)
output_dir.mkdir(parents=True, exist_ok=True)

target = "example (3rfm)" #@param ["example (3rfm)", "upload structure"]

if target == "example (3rfm)":
  # Fetch the example structure from RCSB into the input directory.
  pdbfile = Path(input_dir, '3rfm.pdb')
  urllib.request.urlretrieve('https://files.rcsb.org/download/3rfm.pdb', pdbfile)

elif target == "upload structure":
  # `files.upload()` returns a mapping keyed by uploaded file name; only the
  # first entry is used as the target structure.
  uploaded = files.upload()
  if not uploaded:
    raise ValueError("No file was uploaded. Re-run this cell and select a PDB file.")
  fn = next(iter(uploaded))
  pdbfile = Path(input_dir, fn)
  # Move the uploaded file from the working directory into the input directory.
  Path(fn).rename(pdbfile)

## Define binding pocket

You can choose between two options to define the binding pocket:
1. **list of residues:** provide a list where each residue is specified as `<chain_id>:<res_id>`, e.g., `A:1 A:2 A:3 A:4 A:5 A:6 A:7`
2. **reference ligand:** if the uploaded PDB structure contains a reference ligand in the target pocket, you can specify its location as `<chain_id>:<res_id>` and the pocket will be extracted automatically

In [3]:
#@title { run: "auto" }
import ipywidgets as widgets

#@markdown **Note:** This cell is an interactive widget and the values will be updated automatically every time you change them. You do not need to execute the cell again. If you do, the default values will be reinserted.

pocket_definition = "reference ligand" #@param ["list of residues", "reference ligand"]

# Each mode prints its label, records the matching flag in `pocket_flag`, and
# creates a text box `w` pre-filled with a default value for that mode.
if pocket_definition == "reference ligand":
  print('reference_ligand:')
  pocket_flag = "--ref_ligand"
  w = widgets.Text(value='A:330')
elif pocket_definition == "list of residues":
  print('pocket_residues:')
  pocket_flag = "--resi_list"
  w = widgets.Text(
      value='A:9 A:59 A:60 A:62 A:63 A:64 A:66 A:67 A:80 A:81 A:84 A:85 A:88 A:167 A:168 A:169 A:170 A:172 A:174 A:177 A:181 A:246 A:249 A:250 A:252 A:253 A:256 A:265 A:267 A:270 A:271 A:273 A:274 A:275 A:277 A:278'
  )

# Render the text box; its current content is available as `w.value`.
display(w)

reference_ligand:


Text(value='A:330')

## Settings

Notes:
- `timesteps < 500` is an experimental feature
- `resamplings` and `jump_length` only pertain to the inpainting model

In [4]:
#@markdown ## Sampling
n_samples = 10 #@param {type:"slider", min:1, max:100, step:1}
ligand_nodes = 20 #@param {type:"integer"}

model = "Conditional model (Binding MOAD)" #@param ["Conditional model (Binding MOAD)", "Inpainting model (Binding MOAD)"]
# Both checkpoints are resolved against the same absolute directory. The
# sampling cell changes the working directory to /content/DiffSBDD before
# loading, so a path relative to /content would not resolve there.
checkpoint_dir = Path('/content', 'DiffSBDD', 'checkpoints')
checkpoint = checkpoint_dir / ('moad_fullatom_cond.ckpt' if model == "Conditional model (Binding MOAD)" else 'moad_fullatom_joint.ckpt')

timesteps = 100 #@param {type:"slider", min:1, max:500, step:1}

#@markdown  ## Inpainting parameters
resamplings = 1 #@param {type:"integer"}
jump_length = 1 #@param {type:"integer"}

#@markdown  ## Post-processing
# Each boolean form value is immediately replaced by a flag string
# ("" when disabled); later code compares against these strings.
keep_all_fragments = False #@param {type:"boolean"}
keep_all_fragments = "--all_frags" if keep_all_fragments else ""
sanitize = True #@param {type:"boolean"}
sanitize = "--sanitize" if sanitize else ""
relax = True #@param {type:"boolean"}
relax = "--relax" if relax else ""

In [5]:
#@title Run sampling (this will take a few minutes; runtime depends on the input parameters `n_samples`, `timesteps` etc.)
%%capture
%cd /content/DiffSBDD

import argparse
from pathlib import Path
import torch
import utils
from lightning_modules import LigandPocketDDPM


pdb_id = Path(pdbfile).stem
pocket = w.value

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load model
model = LigandPocketDDPM.load_from_checkpoint(checkpoint, map_location=device)
model = model.to(device)

num_nodes_lig = torch.ones(n_samples, dtype=int) * ligand_nodes

if pocket_flag == '--ref_ligand':
  resi_list = None
  ref_ligand = pocket
else:
  resi_list = pocket.split()
  ref_ligand = None

molecules = model.generate_ligands(
    pdbfile, n_samples, resi_list, ref_ligand,
    num_nodes_lig, (sanitize == '--sanitize'),
    largest_frag=not (keep_all_fragments == "--all_frags"),
    relax_iter=(200 if (relax == "--relax") else 0),
    resamplings=resamplings, jump_length=jump_length,
    timesteps=timesteps
)

# Make SDF files
utils.write_sdf_file(Path(output_dir, f'{pdb_id}_mol.sdf'), molecules)

In [None]:
#@title Show generated molecules

import sys
# NOTE(review): hard-coded Python 3.9 site-packages path; presumably where the
# conda environment installs py3Dmol. Confirm it matches the interpreter in use.
sys.path.append("/usr/local/lib/python3.9/site-packages")
import py3Dmol

# SDF file written by the sampling cell for the current input structure.
sdf_path = Path(output_dir, f"{pdbfile.stem}_mol.sdf")

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
# read_text() opens, reads and closes the file, so no file handle is left open.
view.addModel(pdbfile.read_text(), 'pdb')
view.setStyle({'model': -1}, {'cartoon': {'color': 'lime'}})
# view.addSurface(py3Dmol.VDW, {'opacity': 0.4, 'color': 'lime'})
# The generated molecules are added as frames of one model so they can be animated.
view.addModelsAsFrames(sdf_path.read_text())
view.setStyle({'model': -1}, {'stick': {}})
view.zoomTo({'model': -1})
view.zoom(0.5)
if target == "example (3rfm)":
  # Extra rotation applied only for the bundled example structure.
  view.rotate(90, 'y')
view.animate({'loop': "forward", 'interval': 1000})
view.show()

In [None]:
#@title Download .sdf file
# Send the SDF file generated for the current input structure to the browser.
files.download(output_dir / f"{pdbfile.stem}_mol.sdf")