Diff of /colab/DiffSBDD.ipynb [000000] .. [607087]

Switch to side-by-side view

--- a
+++ b/colab/DiffSBDD.ipynb
@@ -0,0 +1,468 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU",
+    "widgets": {
+      "application/vnd.jupyter.widget-state+json": {
+        "e293d4b7fb5e4d60a9303400a571d481": {
+          "model_module": "@jupyter-widgets/controls",
+          "model_name": "TextModel",
+          "model_module_version": "1.5.0",
+          "state": {
+            "_dom_classes": [],
+            "_model_module": "@jupyter-widgets/controls",
+            "_model_module_version": "1.5.0",
+            "_model_name": "TextModel",
+            "_view_count": null,
+            "_view_module": "@jupyter-widgets/controls",
+            "_view_module_version": "1.5.0",
+            "_view_name": "TextView",
+            "continuous_update": true,
+            "description": "",
+            "description_tooltip": null,
+            "disabled": false,
+            "layout": "IPY_MODEL_9ccaa9faa36646678a0278cd651c3f30",
+            "placeholder": "​",
+            "style": "IPY_MODEL_50a9c77fae984dab908378418b726eff",
+            "value": "A:330"
+          }
+        },
+        "9ccaa9faa36646678a0278cd651c3f30": {
+          "model_module": "@jupyter-widgets/base",
+          "model_name": "LayoutModel",
+          "model_module_version": "1.2.0",
+          "state": {
+            "_model_module": "@jupyter-widgets/base",
+            "_model_module_version": "1.2.0",
+            "_model_name": "LayoutModel",
+            "_view_count": null,
+            "_view_module": "@jupyter-widgets/base",
+            "_view_module_version": "1.2.0",
+            "_view_name": "LayoutView",
+            "align_content": null,
+            "align_items": null,
+            "align_self": null,
+            "border": null,
+            "bottom": null,
+            "display": null,
+            "flex": null,
+            "flex_flow": null,
+            "grid_area": null,
+            "grid_auto_columns": null,
+            "grid_auto_flow": null,
+            "grid_auto_rows": null,
+            "grid_column": null,
+            "grid_gap": null,
+            "grid_row": null,
+            "grid_template_areas": null,
+            "grid_template_columns": null,
+            "grid_template_rows": null,
+            "height": null,
+            "justify_content": null,
+            "justify_items": null,
+            "left": null,
+            "margin": null,
+            "max_height": null,
+            "max_width": null,
+            "min_height": null,
+            "min_width": null,
+            "object_fit": null,
+            "object_position": null,
+            "order": null,
+            "overflow": null,
+            "overflow_x": null,
+            "overflow_y": null,
+            "padding": null,
+            "right": null,
+            "top": null,
+            "visibility": null,
+            "width": null
+          }
+        },
+        "50a9c77fae984dab908378418b726eff": {
+          "model_module": "@jupyter-widgets/controls",
+          "model_name": "DescriptionStyleModel",
+          "model_module_version": "1.5.0",
+          "state": {
+            "_model_module": "@jupyter-widgets/controls",
+            "_model_module_version": "1.5.0",
+            "_model_name": "DescriptionStyleModel",
+            "_view_count": null,
+            "_view_module": "@jupyter-widgets/base",
+            "_view_module_version": "1.2.0",
+            "_view_name": "StyleView",
+            "description_width": ""
+          }
+        }
+      }
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "view-in-github",
+        "colab_type": "text"
+      },
+      "source": [
+        "<a href=\"https://colab.research.google.com/github/arneschneuing/DiffSBDD/blob/main/colab/DiffSBDD.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# DiffSBDD: Structure-based Drug Design with Equivariant Diffusion Models\n",
+        "\n",
+        "[**[Paper]**](https://arxiv.org/abs/2210.13695)\n",
+        "[**[Code]**](https://github.com/arneschneuing/DiffSBDD)\n",
+        "\n",
+        "Make sure to select `Runtime` -> `Change runtime type` -> `GPU` before you run the script.\n",
+        "\n",
+        "<img src=\"https://raw.githubusercontent.com/arneschneuing/DiffSBDD/main/img/overview.png\" height=250>"
+      ],
+      "metadata": {
+        "id": "m12HrhIsNKkS"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@title  Install condacolab (the kernel will be restarted, after that you can execute the remaining cells)\n",
+        "!pip install -q condacolab\n",
+        "import condacolab\n",
+        "condacolab.install()"
+      ],
+      "metadata": {
+        "cellView": "form",
+        "id": "iAiLo8O8klSG"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@title Install dependencies (this will take about 5-10 minutes)\n",
+        "%cd /content\n",
+        "\n",
+        "import os\n",
+        "\n",
+        "commands = [\n",
+        "    \"pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu118\",\n",
+        "    \"pip install pytorch-lightning==1.8.4\",\n",
+        "    \"pip install wandb==0.13.1\",\n",
+        "    \"pip install rdkit==2022.3.3\",\n",
+        "    \"pip install biopython==1.79\",\n",
+        "    \"pip install imageio==2.21.2\",\n",
+        "    \"pip install scipy==1.7.3\",\n",
+        "    \"pip install pyg-lib torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu118.html\",\n",
+        "    \"pip install networkx==2.8.6\",\n",
+        "    \"pip install py3Dmol==1.8.1\",\n",
+        "    \"conda install openbabel -c conda-forge\",\n",
+        "    \"git clone https://github.com/arneschneuing/DiffSBDD.git\",\n",
+        "    \"mkdir -p /content/DiffSBDD/checkpoints\",\n",
+        "    \"wget -P /content/DiffSBDD/checkpoints https://zenodo.org/record/8183747/files/moad_fullatom_cond.ckpt\",\n",
+        "    \"wget -P /content/DiffSBDD/checkpoints https://zenodo.org/record/8183747/files/moad_fullatom_joint.ckpt\",\n",
+        "]\n",
+        "\n",
+        "errors = {}\n",
+        "\n",
+        "if not os.path.isfile(\"/content/READY\"):\n",
+        "  for cmd in commands:\n",
+        "    # os.system(cmd)\n",
+        "    with os.popen(cmd) as f:\n",
+        "      out = f.read()\n",
+        "      status = f.close()\n",
+        "\n",
+        "    if status is not None:\n",
+        "      errors[cmd] = out\n",
+        "      print(f\"\\n\\nAn error occurred while running '{cmd}'\\n\")\n",
+        "      print(\"Status:\\t\", status)\n",
+        "      print(\"Message:\\t\", out)\n",
+        "\n",
+        "if len(errors) == 0:\n",
+        "  os.system(\"touch /content/READY\")"
+      ],
+      "metadata": {
+        "id": "bZ5Q_kLdVIp7",
+        "cellView": "form"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Choose target PDB"
+      ],
+      "metadata": {
+        "id": "e46yNUYbdqZ3"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from google.colab import files\n",
+        "from google.colab import output\n",
+        "output.enable_custom_widget_manager()\n",
+        "import os.path\n",
+        "from pathlib import Path\n",
+        "import urllib\n",
+        "import os\n",
+        "\n",
+        "input_dir = Path(\"/content/input_pdbs/\")\n",
+        "output_dir = Path(\"/content/output_sdfs/\")\n",
+        "input_dir.mkdir(exist_ok=True)\n",
+        "output_dir.mkdir(exist_ok=True)\n",
+        "\n",
+        "target = \"example (3rfm)\" #@param [\"example (3rfm)\", \"upload structure\"]\n",
+        "\n",
+        "if target == \"example (3rfm)\":\n",
+        "  pdbfile = Path(input_dir, '3rfm.pdb')\n",
+        "  urllib.request.urlretrieve('http://files.rcsb.org/download/3rfm.pdb', pdbfile)\n",
+        "\n",
+        "elif target == \"upload structure\":\n",
+        "  uploaded = files.upload()\n",
+        "  fn = list(uploaded.keys())[0]\n",
+        "  pdbfile = Path(input_dir, fn)\n",
+        "  Path(fn).rename(pdbfile)"
+      ],
+      "metadata": {
+        "cellView": "form",
+        "id": "tzkQJJeNdJMa"
+      },
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Define binding pocket\n",
+        "\n",
+        "You can choose between two options to define the binding pocket:\n",
+        "1. **list of residues:** provide a list where each residue is specified as `<chain_id>:<res_id>`, e.g, `A:1 A:2 A:3 A:4 A:5 A:6 A:7`\n",
+        "2. **reference ligand:** if the uploaded PDB structure contains a reference ligand in the target pocket, you can specify its location as `<chain_id>:<res_id>` and the pocket will be extracted automatically"
+      ],
+      "metadata": {
+        "id": "eif7HG0Bd2qj"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@title { run: \"auto\" }\n",
+        "import ipywidgets as widgets\n",
+        "\n",
+        "#@markdown **Note:** This cell is an interactive widget and the values will be updated automatically every time you change them. You do not need to execute the cell again. If you do, the default values will be reinserted.\n",
+        "\n",
+        "pocket_definition = \"reference ligand\" #@param [\"list of residues\", \"reference ligand\"]\n",
+        "\n",
+        "if pocket_definition == \"list of residues\":\n",
+        "  print('pocket_residues:')\n",
+        "  w = widgets.Text(value='A:9 A:59 A:60 A:62 A:63 A:64 A:66 A:67 A:80 A:81 A:84 A:85 A:88 A:167 A:168 A:169 A:170 A:172 A:174 A:177 A:181 A:246 A:249 A:250 A:252 A:253 A:256 A:265 A:267 A:270 A:271 A:273 A:274 A:275 A:277 A:278')\n",
+        "  pocket_flag = \"--resi_list\"\n",
+        "elif pocket_definition == \"reference ligand\":\n",
+        "  print('reference_ligand:')\n",
+        "  w = widgets.Text(value='A:330')\n",
+        "  pocket_flag = \"--ref_ligand\"\n",
+        "\n",
+        "display(w)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 67,
+          "referenced_widgets": [
+            "e293d4b7fb5e4d60a9303400a571d481",
+            "9ccaa9faa36646678a0278cd651c3f30",
+            "50a9c77fae984dab908378418b726eff"
+          ]
+        },
+        "cellView": "form",
+        "id": "E9AMhAo_VlUo",
+        "outputId": "b4f9f05f-123c-40d4-c4bd-77ef5e19cb63"
+      },
+      "execution_count": 3,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "reference_ligand:\n"
+          ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "Text(value='A:330')"
+            ],
+            "application/vnd.jupyter.widget-view+json": {
+              "version_major": 2,
+              "version_minor": 0,
+              "model_id": "e293d4b7fb5e4d60a9303400a571d481"
+            }
+          },
+          "metadata": {
+            "application/vnd.jupyter.widget-view+json": {
+              "colab": {
+                "custom_widget_manager": {
+                  "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/b3e629b1971e1542/manager.min.js"
+                }
+              }
+            }
+          }
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Settings\n",
+        "\n",
+        "Notes:\n",
+        "- `timesteps < 500` is an experimental feature\n",
+        "- `resamplings` and `jump_length` only pertain to the inpainting model"
+      ],
+      "metadata": {
+        "id": "eYgXjygkeG14"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@markdown ## Sampling\n",
+        "n_samples = 10 #@param {type:\"slider\", min:1, max:100, step:1}\n",
+        "ligand_nodes = 20 #@param {type:\"integer\"}\n",
+        "\n",
+        "model = \"Conditional model (Binding MOAD)\" #@param [\"Conditional model (Binding MOAD)\", \"Inpainting model (Binding MOAD)\"]\n",
+        "checkpoint = Path('/content', 'DiffSBDD', 'checkpoints', 'moad_fullatom_cond.ckpt') if model == \"Conditional model (Binding MOAD)\" else Path('DiffSBDD', 'checkpoints', 'moad_fullatom_joint.ckpt')\n",
+        "\n",
+        "timesteps = 100 #@param {type:\"slider\", min:1, max:500, step:1}\n",
+        "\n",
+        "#@markdown  ## Inpainting parameters\n",
+        "resamplings = 1 #@param {type:\"integer\"}\n",
+        "jump_length = 1 #@param {type:\"integer\"}\n",
+        "\n",
+        "#@markdown  ## Post-processing\n",
+        "keep_all_fragments = False #@param {type:\"boolean\"}\n",
+        "keep_all_fragments = \"--all_frags\" if keep_all_fragments else \"\"\n",
+        "sanitize = True #@param {type:\"boolean\"}\n",
+        "sanitize = \"--sanitize\" if sanitize else \"\"\n",
+        "relax = True #@param {type:\"boolean\"}\n",
+        "relax = \"--relax\" if relax else \"\""
+      ],
+      "metadata": {
+        "id": "VQ6xa7EPMtyI",
+        "cellView": "form"
+      },
+      "execution_count": 4,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@title Run sampling (this will take a few minutes; runtime depends on the input parameters `n_samples`, `timesteps` etc.)\n",
+        "%%capture\n",
+        "%cd /content/DiffSBDD\n",
+        "\n",
+        "import argparse\n",
+        "from pathlib import Path\n",
+        "import torch\n",
+        "import utils\n",
+        "from lightning_modules import LigandPocketDDPM\n",
+        "\n",
+        "\n",
+        "pdb_id = Path(pdbfile).stem\n",
+        "pocket = w.value\n",
+        "\n",
+        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+        "\n",
+        "# Load model\n",
+        "model = LigandPocketDDPM.load_from_checkpoint(checkpoint, map_location=device)\n",
+        "model = model.to(device)\n",
+        "\n",
+        "num_nodes_lig = torch.ones(n_samples, dtype=int) * ligand_nodes\n",
+        "\n",
+        "if pocket_flag == '--ref_ligand':\n",
+        "  resi_list = None\n",
+        "  ref_ligand = pocket\n",
+        "else:\n",
+        "  resi_list = pocket.split()\n",
+        "  ref_ligand = None\n",
+        "\n",
+        "molecules = model.generate_ligands(\n",
+        "    pdbfile, n_samples, resi_list, ref_ligand,\n",
+        "    num_nodes_lig, (sanitize == '--sanitize'),\n",
+        "    largest_frag=not (keep_all_fragments == \"--all_frags\"),\n",
+        "    relax_iter=(200 if (relax == \"--relax\") else 0),\n",
+        "    resamplings=resamplings, jump_length=jump_length,\n",
+        "    timesteps=timesteps\n",
+        ")\n",
+        "\n",
+        "# Make SDF files\n",
+        "utils.write_sdf_file(Path(output_dir, f'{pdb_id}_mol.sdf'), molecules)"
+      ],
+      "metadata": {
+        "id": "pWitBCDUoRBw",
+        "cellView": "form"
+      },
+      "execution_count": 5,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@title Show generated molecules\n",
+        "\n",
+        "import sys\n",
+        "sys.path.append(\"/usr/local/lib/python3.9/site-packages\")\n",
+        "import py3Dmol\n",
+        "\n",
+        "view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n",
+        "view.addModel(open(pdbfile, 'r').read(), 'pdb')\n",
+        "view.setStyle({'model': -1}, {'cartoon': {'color': 'lime'}})\n",
+        "# view.addSurface(py3Dmol.VDW, {'opacity': 0.4, 'color': 'lime'})\n",
+        "view.addModelsAsFrames(open(Path(output_dir, f\"{pdbfile.stem}_mol.sdf\"), 'r').read())\n",
+        "view.setStyle({'model': -1}, {'stick': {}})\n",
+        "view.zoomTo({'model': -1})\n",
+        "view.zoom(0.5)\n",
+        "if target == \"example (3rfm)\":\n",
+        "  view.rotate(90, 'y')\n",
+        "view.animate({'loop': \"forward\", 'interval': 1000})\n",
+        "view.show()"
+      ],
+      "metadata": {
+        "cellView": "form",
+        "id": "lVyysoc0Rp_-"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@title Download .sdf file\n",
+        "files.download(Path(output_dir, f\"{pdbfile.stem}_mol.sdf\"))"
+      ],
+      "metadata": {
+        "cellView": "form",
+        "id": "3lQUv8rmQRd_"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file