[480cb7]: / 7_ligandbased_screening.ipynb

Download this file

2083 lines (2083 with data), 249.0 kB

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyNmOGC0GGNjIbcmUUTxZoYs",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/francescopatane96/Computer_aided_drug_discovery_kit/blob/main/7_ligandbased_screening.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 7. Ligand-based screening with ML"
      ],
      "metadata": {
        "id": "mnDnpeJiHC7z"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rr8OrkQkGzf7"
      },
      "outputs": [],
      "source": [
        "# Install RDKit into the environment of the running kernel.\n",
        "# The explicit %pip magic is used so the cell does not depend on IPython automagic.\n",
        "%pip install rdkit"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from pathlib import Path\n",
        "from warnings import filterwarnings\n",
        "import time\n",
        "\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "from sklearn import svm, metrics, clone\n",
        "from sklearn.ensemble import RandomForestClassifier\n",
        "from sklearn.neural_network import MLPClassifier\n",
        "from sklearn.model_selection import KFold, train_test_split\n",
        "from sklearn.metrics import auc, accuracy_score, recall_score\n",
        "from sklearn.metrics import roc_curve, roc_auc_score\n",
        "import matplotlib.pyplot as plt\n",
        "from rdkit import Chem\n",
        "from rdkit.Chem import MACCSkeys, rdFingerprintGenerator\n",
        "\n",
        "\n",
        "\n",
        "# Silence some expected warnings\n",
        "# (note: this ignores every warning category for the rest of the session)\n",
        "filterwarnings(\"ignore\")\n",
        "# Fix seed for reproducible results\n",
        "# SEED is also available to pass as `random_state=SEED` to estimators/splitters.\n",
        "SEED = 22\n",
        "# Seed NumPy's global generator so code relying on it is repeatable across runs\n",
        "np.random.seed(SEED)\n"
      ],
      "metadata": {
        "id": "Hy_noqtzIfZv"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the compound table saved by the previous talktorials;\n",
        "# the first CSV column is used as the row index\n",
        "chembl_df = pd.read_csv(\"compounds_lipinski.csv\", index_col=0)\n",
        "\n",
        "# Report the table dimensions, then preview the first rows\n",
        "print(\"Shape of dataframe : \", chembl_df.shape)\n",
        "chembl_df.head()\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 461
        },
        "id": "jyV1NfgZIlrX",
        "outputId": "921fc759-7f54-41e5-8679-27fc822422a7"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Shape of dataframe :  (5431, 15)\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "  molecule_chembl_id   IC50 units                               smiles  \\\n",
              "0        CHEMBL63786  0.003    nM    Brc1cccc(Nc2ncnc3cc4ccccc4cc23)c1   \n",
              "1        CHEMBL53711  0.006    nM   CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1   \n",
              "2        CHEMBL35820  0.006    nM  CCOc1cc2ncnc(Nc3cccc(Br)c3)c2cc1OCC   \n",
              "3        CHEMBL53753  0.008    nM      CNc1cc2c(Nc3cccc(Br)c3)ncnc2cn1   \n",
              "4        CHEMBL66031  0.008    nM  Brc1cccc(Nc2ncnc3cc4[nH]cnc4cc23)c1   \n",
              "\n",
              "       pIC50                                             ROMol  \\\n",
              "0  11.522879  <rdkit.Chem.rdchem.Mol object at 0x7c4034736490>   \n",
              "1  11.221849  <rdkit.Chem.rdchem.Mol object at 0x7c4034736500>   \n",
              "2  11.221849  <rdkit.Chem.rdchem.Mol object at 0x7c4034736570>   \n",
              "3  11.096910  <rdkit.Chem.rdchem.Mol object at 0x7c4034736420>   \n",
              "4  11.096910  <rdkit.Chem.rdchem.Mol object at 0x7c40347365e0>   \n",
              "\n",
              "   molecular_weight  n_hba  n_hbd    logp  molecular_weight.1  n_hba.1  \\\n",
              "0        349.021459      3      1  5.2891          349.021459        3   \n",
              "1        343.043258      5      1  3.5969          343.043258        5   \n",
              "2        387.058239      5      1  4.9333          387.058239        5   \n",
              "3        329.027607      5      2  3.5726          329.027607        5   \n",
              "4        339.011957      4      2  4.0122          339.011957        4   \n",
              "\n",
              "   n_hbd.1  logp.1  ro5_fulfilled  \n",
              "0        1  5.2891           True  \n",
              "1        1  3.5969           True  \n",
              "2        1  4.9333           True  \n",
              "3        2  3.5726           True  \n",
              "4        2  4.0122           True  "
            ],
            "text/html": [
              "\n",
              "\n",
              "  <div id=\"df-d47f121f-d018-4045-824b-2343c672fa24\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>molecule_chembl_id</th>\n",
              "      <th>IC50</th>\n",
              "      <th>units</th>\n",
              "      <th>smiles</th>\n",
              "      <th>pIC50</th>\n",
              "      <th>ROMol</th>\n",
              "      <th>molecular_weight</th>\n",
              "      <th>n_hba</th>\n",
              "      <th>n_hbd</th>\n",
              "      <th>logp</th>\n",
              "      <th>molecular_weight.1</th>\n",
              "      <th>n_hba.1</th>\n",
              "      <th>n_hbd.1</th>\n",
              "      <th>logp.1</th>\n",
              "      <th>ro5_fulfilled</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>CHEMBL63786</td>\n",
              "      <td>0.003</td>\n",
              "      <td>nM</td>\n",
              "      <td>Brc1cccc(Nc2ncnc3cc4ccccc4cc23)c1</td>\n",
              "      <td>11.522879</td>\n",
              "      <td>&lt;rdkit.Chem.rdchem.Mol object at 0x7c4034736490&gt;</td>\n",
              "      <td>349.021459</td>\n",
              "      <td>3</td>\n",
              "      <td>1</td>\n",
              "      <td>5.2891</td>\n",
              "      <td>349.021459</td>\n",
              "      <td>3</td>\n",
              "      <td>1</td>\n",
              "      <td>5.2891</td>\n",
              "      <td>True</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>CHEMBL53711</td>\n",
              "      <td>0.006</td>\n",
              "      <td>nM</td>\n",
              "      <td>CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1</td>\n",
              "      <td>11.221849</td>\n",
              "      <td>&lt;rdkit.Chem.rdchem.Mol object at 0x7c4034736500&gt;</td>\n",
              "      <td>343.043258</td>\n",
              "      <td>5</td>\n",
              "      <td>1</td>\n",
              "      <td>3.5969</td>\n",
              "      <td>343.043258</td>\n",
              "      <td>5</td>\n",
              "      <td>1</td>\n",
              "      <td>3.5969</td>\n",
              "      <td>True</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>CHEMBL35820</td>\n",
              "      <td>0.006</td>\n",
              "      <td>nM</td>\n",
              "      <td>CCOc1cc2ncnc(Nc3cccc(Br)c3)c2cc1OCC</td>\n",
              "      <td>11.221849</td>\n",
              "      <td>&lt;rdkit.Chem.rdchem.Mol object at 0x7c4034736570&gt;</td>\n",
              "      <td>387.058239</td>\n",
              "      <td>5</td>\n",
              "      <td>1</td>\n",
              "      <td>4.9333</td>\n",
              "      <td>387.058239</td>\n",
              "      <td>5</td>\n",
              "      <td>1</td>\n",
              "      <td>4.9333</td>\n",
              "      <td>True</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>CHEMBL53753</td>\n",
              "      <td>0.008</td>\n",
              "      <td>nM</td>\n",
              "      <td>CNc1cc2c(Nc3cccc(Br)c3)ncnc2cn1</td>\n",
              "      <td>11.096910</td>\n",
              "      <td>&lt;rdkit.Chem.rdchem.Mol object at 0x7c4034736420&gt;</td>\n",
              "      <td>329.027607</td>\n",
              "      <td>5</td>\n",
              "      <td>2</td>\n",
              "      <td>3.5726</td>\n",
              "      <td>329.027607</td>\n",
              "      <td>5</td>\n",
              "      <td>2</td>\n",
              "      <td>3.5726</td>\n",
              "      <td>True</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>CHEMBL66031</td>\n",
              "      <td>0.008</td>\n",
              "      <td>nM</td>\n",
              "      <td>Brc1cccc(Nc2ncnc3cc4[nH]cnc4cc23)c1</td>\n",
              "      <td>11.096910</td>\n",
              "      <td>&lt;rdkit.Chem.rdchem.Mol object at 0x7c40347365e0&gt;</td>\n",
              "      <td>339.011957</td>\n",
              "      <td>4</td>\n",
              "      <td>2</td>\n",
              "      <td>4.0122</td>\n",
              "      <td>339.011957</td>\n",
              "      <td>4</td>\n",
              "      <td>2</td>\n",
              "      <td>4.0122</td>\n",
              "      <td>True</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-d47f121f-d018-4045-824b-2343c672fa24')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "\n",
              "\n",
              "\n",
              "    <div id=\"df-4aec8cf4-9618-4e6c-b824-da39057a0374\">\n",
              "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-4aec8cf4-9618-4e6c-b824-da39057a0374')\"\n",
              "              title=\"Suggest charts.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "      </button>\n",
              "    </div>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "    background-color: #E8F0FE;\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: #1967D2;\n",
              "    height: 32px;\n",
              "    padding: 0 0 0 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: #E2EBFA;\n",
              "    box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: #174EA6;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "    background-color: #3B4455;\n",
              "    fill: #D2E3FC;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart:hover {\n",
              "    background-color: #434B5C;\n",
              "    box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "    filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "    fill: #FFFFFF;\n",
              "  }\n",
              "</style>\n",
              "\n",
              "    <script>\n",
              "      async function quickchart(key) {\n",
              "        const containerElement = document.querySelector('#' + key);\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      }\n",
              "    </script>\n",
              "\n",
              "      <script>\n",
              "\n",
              "function displayQuickchartButton(domScope) {\n",
              "  let quickchartButtonEl =\n",
              "    domScope.querySelector('#df-4aec8cf4-9618-4e6c-b824-da39057a0374 button.colab-df-quickchart');\n",
              "  quickchartButtonEl.style.display =\n",
              "    google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "}\n",
              "\n",
              "        displayQuickchartButton(document);\n",
              "      </script>\n",
              "      <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-d47f121f-d018-4045-824b-2343c672fa24 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-d47f121f-d018-4045-824b-2343c672fa24');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n"
            ]
          },
          "metadata": {},
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Keep only the columns we want.\n",
        "# Take an explicit copy: a column is added to this frame in a later cell, and\n",
        "# writing into a plain column selection triggers pandas' SettingWithCopy situation.\n",
        "chembl_df = chembl_df[[\"molecule_chembl_id\", \"smiles\", \"pIC50\"]].copy()\n",
        "chembl_df.head()\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "nk9OT7qRIxq_",
        "outputId": "3eefa62d-3dc0-409f-b5c4-5a079cf91b30"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "  molecule_chembl_id                               smiles      pIC50\n",
              "0        CHEMBL63786    Brc1cccc(Nc2ncnc3cc4ccccc4cc23)c1  11.522879\n",
              "1        CHEMBL53711   CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1  11.221849\n",
              "2        CHEMBL35820  CCOc1cc2ncnc(Nc3cccc(Br)c3)c2cc1OCC  11.221849\n",
              "3        CHEMBL53753      CNc1cc2c(Nc3cccc(Br)c3)ncnc2cn1  11.096910\n",
              "4        CHEMBL66031  Brc1cccc(Nc2ncnc3cc4[nH]cnc4cc23)c1  11.096910"
            ],
            "text/html": [
              "\n",
              "\n",
              "  <div id=\"df-34b5d525-b689-46c5-83f2-67d925a1f297\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>molecule_chembl_id</th>\n",
              "      <th>smiles</th>\n",
              "      <th>pIC50</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>CHEMBL63786</td>\n",
              "      <td>Brc1cccc(Nc2ncnc3cc4ccccc4cc23)c1</td>\n",
              "      <td>11.522879</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>CHEMBL53711</td>\n",
              "      <td>CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1</td>\n",
              "      <td>11.221849</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>CHEMBL35820</td>\n",
              "      <td>CCOc1cc2ncnc(Nc3cccc(Br)c3)c2cc1OCC</td>\n",
              "      <td>11.221849</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>CHEMBL53753</td>\n",
              "      <td>CNc1cc2c(Nc3cccc(Br)c3)ncnc2cn1</td>\n",
              "      <td>11.096910</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>CHEMBL66031</td>\n",
              "      <td>Brc1cccc(Nc2ncnc3cc4[nH]cnc4cc23)c1</td>\n",
              "      <td>11.096910</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-34b5d525-b689-46c5-83f2-67d925a1f297')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "\n",
              "\n",
              "\n",
              "    <div id=\"df-e34d1f7b-7c73-46e2-accf-8745d8a3545d\">\n",
              "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-e34d1f7b-7c73-46e2-accf-8745d8a3545d')\"\n",
              "              title=\"Suggest charts.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "      </button>\n",
              "    </div>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "    background-color: #E8F0FE;\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: #1967D2;\n",
              "    height: 32px;\n",
              "    padding: 0 0 0 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: #E2EBFA;\n",
              "    box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: #174EA6;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "    background-color: #3B4455;\n",
              "    fill: #D2E3FC;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart:hover {\n",
              "    background-color: #434B5C;\n",
              "    box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "    filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "    fill: #FFFFFF;\n",
              "  }\n",
              "</style>\n",
              "\n",
              "    <script>\n",
              "      async function quickchart(key) {\n",
              "        const containerElement = document.querySelector('#' + key);\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      }\n",
              "    </script>\n",
              "\n",
              "      <script>\n",
              "\n",
              "function displayQuickchartButton(domScope) {\n",
              "  let quickchartButtonEl =\n",
              "    domScope.querySelector('#df-e34d1f7b-7c73-46e2-accf-8745d8a3545d button.colab-df-quickchart');\n",
              "  quickchartButtonEl.style.display =\n",
              "    google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "}\n",
              "\n",
              "        displayQuickchartButton(document);\n",
              "      </script>\n",
              "      <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-34b5d525-b689-46c5-83f2-67d925a1f297 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-34b5d525-b689-46c5-83f2-67d925a1f297');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Add column for activity:\n",
        "# mark every molecule with a pIC50 of >= 6.3 as active (1.0), 0.0 otherwise.\n",
        "# The boolean mask is applied row by row, so the labels stay correct even if\n",
        "# index labels are repeated (a label-based .loc lookup would not be).\n",
        "chembl_df[\"active\"] = (chembl_df[\"pIC50\"] >= 6.3).astype(float)\n",
        "\n",
        "# NBVAL_CHECK_OUTPUT\n",
        "# Count the actives once and derive the inactives from the table length\n",
        "n_active = int(chembl_df[\"active\"].sum())\n",
        "print(\"Number of active compounds:\", n_active)\n",
        "print(\"Number of inactive compounds:\", len(chembl_df) - n_active)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eESQTyr-I2XI",
        "outputId": "142c6dae-95a0-4b06-ffa3-1ce277670dc1"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Number of active compounds: 3240\n",
            "Number of inactive compounds: 2191\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Preview the first five rows again, now including the \"active\" label column\n",
        "chembl_df.head(n=5)\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "MhlIbYWUI5bG",
        "outputId": "12ff3941-fec5-40b2-8cd0-d2fb2ace6e98"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "  molecule_chembl_id                               smiles      pIC50  active\n",
              "0        CHEMBL63786    Brc1cccc(Nc2ncnc3cc4ccccc4cc23)c1  11.522879     1.0\n",
              "1        CHEMBL53711   CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1  11.221849     1.0\n",
              "2        CHEMBL35820  CCOc1cc2ncnc(Nc3cccc(Br)c3)c2cc1OCC  11.221849     1.0\n",
              "3        CHEMBL53753      CNc1cc2c(Nc3cccc(Br)c3)ncnc2cn1  11.096910     1.0\n",
              "4        CHEMBL66031  Brc1cccc(Nc2ncnc3cc4[nH]cnc4cc23)c1  11.096910     1.0"
            ],
            "text/html": [
              "\n",
              "\n",
              "  <div id=\"df-a0bae0f2-6ab1-4f7e-802d-b1887138244d\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>molecule_chembl_id</th>\n",
              "      <th>smiles</th>\n",
              "      <th>pIC50</th>\n",
              "      <th>active</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>CHEMBL63786</td>\n",
              "      <td>Brc1cccc(Nc2ncnc3cc4ccccc4cc23)c1</td>\n",
              "      <td>11.522879</td>\n",
              "      <td>1.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>CHEMBL53711</td>\n",
              "      <td>CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1</td>\n",
              "      <td>11.221849</td>\n",
              "      <td>1.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>CHEMBL35820</td>\n",
              "      <td>CCOc1cc2ncnc(Nc3cccc(Br)c3)c2cc1OCC</td>\n",
              "      <td>11.221849</td>\n",
              "      <td>1.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>CHEMBL53753</td>\n",
              "      <td>CNc1cc2c(Nc3cccc(Br)c3)ncnc2cn1</td>\n",
              "      <td>11.096910</td>\n",
              "      <td>1.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>CHEMBL66031</td>\n",
              "      <td>Brc1cccc(Nc2ncnc3cc4[nH]cnc4cc23)c1</td>\n",
              "      <td>11.096910</td>\n",
              "      <td>1.0</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-a0bae0f2-6ab1-4f7e-802d-b1887138244d')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "\n",
              "\n",
              "\n",
              "    <div id=\"df-509b0650-3274-4e9c-96df-6e5dd98c2d5e\">\n",
              "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-509b0650-3274-4e9c-96df-6e5dd98c2d5e')\"\n",
              "              title=\"Suggest charts.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "      </button>\n",
              "    </div>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "    background-color: #E8F0FE;\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: #1967D2;\n",
              "    height: 32px;\n",
              "    padding: 0 0 0 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: #E2EBFA;\n",
              "    box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: #174EA6;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "    background-color: #3B4455;\n",
              "    fill: #D2E3FC;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart:hover {\n",
              "    background-color: #434B5C;\n",
              "    box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "    filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "    fill: #FFFFFF;\n",
              "  }\n",
              "</style>\n",
              "\n",
              "    <script>\n",
              "      async function quickchart(key) {\n",
              "        const containerElement = document.querySelector('#' + key);\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      }\n",
              "    </script>\n",
              "\n",
              "      <script>\n",
              "\n",
              "function displayQuickchartButton(domScope) {\n",
              "  let quickchartButtonEl =\n",
              "    domScope.querySelector('#df-509b0650-3274-4e9c-96df-6e5dd98c2d5e button.colab-df-quickchart');\n",
              "  quickchartButtonEl.style.display =\n",
              "    google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "}\n",
              "\n",
              "        displayQuickchartButton(document);\n",
              "      </script>\n",
              "      <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-a0bae0f2-6ab1-4f7e-802d-b1887138244d button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-a0bae0f2-6ab1-4f7e-802d-b1887138244d');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def smiles_to_fp(smiles, method=\"maccs\", n_bits=2048):\n",
        "    \"\"\"\n",
        "    Encode a molecule from a SMILES string into a fingerprint.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    smiles : str\n",
        "        The SMILES string defining the molecule.\n",
        "\n",
        "    method : str\n",
        "        The type of fingerprint to use. Default is MACCS keys.\n",
        "\n",
        "    n_bits : int\n",
        "        The length of the fingerprint.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    array\n",
        "        The fingerprint array.\n",
        "\n",
        "    \"\"\"\n",
        "\n",
        "    # convert smiles to RDKit mol object\n",
        "    mol = Chem.MolFromSmiles(smiles)\n",
        "\n",
        "    if method == \"maccs\":\n",
        "        return np.array(MACCSkeys.GenMACCSKeys(mol))\n",
        "    if method == \"morgan2\":\n",
        "        fpg = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=n_bits)\n",
        "        return np.array(fpg.GetFingerprint(mol))\n",
        "    if method == \"morgan3\":\n",
        "        fpg = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=n_bits)\n",
        "        return np.array(fpg.GetFingerprint(mol))\n",
        "    else:\n",
        "        # NBVAL_CHECK_OUTPUT\n",
        "        print(f\"Warning: Wrong method specified: {method}. Default will be used instead.\")\n",
        "        return np.array(MACCSkeys.GenMACCSKeys(mol))"
      ],
      "metadata": {
        "id": "lWZbQUPrI8vu"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "compound_df = chembl_df.copy()"
      ],
      "metadata": {
        "id": "mGWfhNA_I_Ly"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Add column for fingerprint\n",
        "compound_df[\"fp\"] = compound_df[\"smiles\"].apply(smiles_to_fp)\n",
        "compound_df.head(3)\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 344
        },
        "id": "oa1OCn5tJBbe",
        "outputId": "4f366685-12e3-487f-8d24-a3f1fe586b6c"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "  molecule_chembl_id                               smiles      pIC50  active  \\\n",
              "0        CHEMBL63786    Brc1cccc(Nc2ncnc3cc4ccccc4cc23)c1  11.522879     1.0   \n",
              "1        CHEMBL53711   CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1  11.221849     1.0   \n",
              "2        CHEMBL35820  CCOc1cc2ncnc(Nc3cccc(Br)c3)c2cc1OCC  11.221849     1.0   \n",
              "\n",
              "                                                  fp  \n",
              "0  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  \n",
              "1  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  \n",
              "2  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  "
            ],
            "text/html": [
              "\n",
              "\n",
              "  <div id=\"df-8ee2aeb8-05ee-4a9f-b912-997e4dbc2503\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>molecule_chembl_id</th>\n",
              "      <th>smiles</th>\n",
              "      <th>pIC50</th>\n",
              "      <th>active</th>\n",
              "      <th>fp</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>CHEMBL63786</td>\n",
              "      <td>Brc1cccc(Nc2ncnc3cc4ccccc4cc23)c1</td>\n",
              "      <td>11.522879</td>\n",
              "      <td>1.0</td>\n",
              "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>CHEMBL53711</td>\n",
              "      <td>CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1</td>\n",
              "      <td>11.221849</td>\n",
              "      <td>1.0</td>\n",
              "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>CHEMBL35820</td>\n",
              "      <td>CCOc1cc2ncnc(Nc3cccc(Br)c3)c2cc1OCC</td>\n",
              "      <td>11.221849</td>\n",
              "      <td>1.0</td>\n",
              "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-8ee2aeb8-05ee-4a9f-b912-997e4dbc2503')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "\n",
              "\n",
              "\n",
              "    <div id=\"df-854687ac-5f96-4ba0-a9ca-779519efa4b2\">\n",
              "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-854687ac-5f96-4ba0-a9ca-779519efa4b2')\"\n",
              "              title=\"Suggest charts.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "      </button>\n",
              "    </div>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "    background-color: #E8F0FE;\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: #1967D2;\n",
              "    height: 32px;\n",
              "    padding: 0 0 0 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: #E2EBFA;\n",
              "    box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: #174EA6;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "    background-color: #3B4455;\n",
              "    fill: #D2E3FC;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart:hover {\n",
              "    background-color: #434B5C;\n",
              "    box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "    filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "    fill: #FFFFFF;\n",
              "  }\n",
              "</style>\n",
              "\n",
              "    <script>\n",
              "      async function quickchart(key) {\n",
              "        const containerElement = document.querySelector('#' + key);\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      }\n",
              "    </script>\n",
              "\n",
              "      <script>\n",
              "\n",
              "function displayQuickchartButton(domScope) {\n",
              "  let quickchartButtonEl =\n",
              "    domScope.querySelector('#df-854687ac-5f96-4ba0-a9ca-779519efa4b2 button.colab-df-quickchart');\n",
              "  quickchartButtonEl.style.display =\n",
              "    google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "}\n",
              "\n",
              "        displayQuickchartButton(document);\n",
              "      </script>\n",
              "      <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-8ee2aeb8-05ee-4a9f-b912-997e4dbc2503 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-8ee2aeb8-05ee-4a9f-b912-997e4dbc2503');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_roc_curves_for_models(models, test_x, test_y, save_png=False):\n",
        "    \"\"\"\n",
        "    Helper function to plot customized roc curve.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    models: dict\n",
        "        Dictionary of pretrained machine learning models.\n",
        "    test_x: list\n",
        "        Molecular fingerprints for test set.\n",
        "    test_y: list\n",
        "        Associated activity labels for test set.\n",
        "    save_png: bool\n",
        "        Save image to disk (default = False)\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    fig:\n",
        "        Figure.\n",
        "    \"\"\"\n",
        "\n",
        "    fig, ax = plt.subplots()\n",
        "\n",
        "    # Below for loop iterates through your models list\n",
        "    for model in models:\n",
        "        # Select the model\n",
        "        ml_model = model[\"model\"]\n",
        "        # Prediction probability on test set\n",
        "        test_prob = ml_model.predict_proba(test_x)[:, 1]\n",
        "        # Prediction class on test set\n",
        "        test_pred = ml_model.predict(test_x)\n",
        "        # Compute False postive rate and True positive rate\n",
        "        fpr, tpr, thresholds = metrics.roc_curve(test_y, test_prob)\n",
        "        # Calculate Area under the curve to display on the plot\n",
        "        auc = roc_auc_score(test_y, test_prob)\n",
        "        # Plot the computed values\n",
        "        ax.plot(fpr, tpr, label=(f\"{model['label']} AUC area = {auc:.2f}\"))\n",
        "\n",
        "    # Custom settings for the plot\n",
        "    ax.plot([0, 1], [0, 1], \"r--\")\n",
        "    ax.set_xlabel(\"False Positive Rate\")\n",
        "    ax.set_ylabel(\"True Positive Rate\")\n",
        "    ax.set_title(\"Receiver Operating Characteristic\")\n",
        "    ax.legend(loc=\"lower right\")\n",
        "    # Save plot\n",
        "    if save_png:\n",
        "        fig.savefig(f\"roc_auc\", dpi=300, bbox_inches=\"tight\", transparent=True)\n",
        "    return fig"
      ],
      "metadata": {
        "id": "qfmwZX0GJM9n"
      },
      "execution_count": 23,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def model_performance(ml_model, test_x, test_y, verbose=True):\n",
        "    \"\"\"\n",
        "    Helper function to calculate model performance\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    ml_model: sklearn model object\n",
        "        The machine learning model to train.\n",
        "    test_x: list\n",
        "        Molecular fingerprints for test set.\n",
        "    test_y: list\n",
        "        Associated activity labels for test set.\n",
        "    verbose: bool\n",
        "        Print performance measure (default = True)\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    tuple:\n",
        "        Accuracy, sensitivity, specificity, auc on test set.\n",
        "    \"\"\"\n",
        "\n",
        "    # Prediction probability on test set\n",
        "    test_prob = ml_model.predict_proba(test_x)[:, 1]\n",
        "\n",
        "    # Prediction class on test set\n",
        "    test_pred = ml_model.predict(test_x)\n",
        "\n",
        "    # Performance of model on test set\n",
        "    accuracy = accuracy_score(test_y, test_pred)\n",
        "    sens = recall_score(test_y, test_pred)\n",
        "    spec = recall_score(test_y, test_pred, pos_label=0)\n",
        "    auc = roc_auc_score(test_y, test_prob)\n",
        "\n",
        "    if verbose:\n",
        "        # Print performance results\n",
        "        # NBVAL_CHECK_OUTPUT        print(f\"Accuracy: {accuracy:.2}\")\n",
        "        print(f\"Sensitivity: {sens:.2f}\")\n",
        "        print(f\"Specificity: {spec:.2f}\")\n",
        "        print(f\"AUC: {auc:.2f}\")\n",
        "\n",
        "    return accuracy, sens, spec, auc"
      ],
      "metadata": {
        "id": "W5Dg1n1kJPo-"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def model_training_and_validation(ml_model, name, splits, verbose=True):\n",
        "    \"\"\"\n",
        "    Fit a machine learning model on a random train-test split of the data\n",
        "    and return the performance measures.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    ml_model: sklearn model object\n",
        "        The machine learning model to train.\n",
        "    name: str\n",
        "        Name of machine learning algorithm: RF, SVM, ANN\n",
        "    splits: list\n",
        "        List of desciptor and label data: train_x, test_x, train_y, test_y.\n",
        "    verbose: bool\n",
        "        Print performance info (default = True)\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    tuple:\n",
        "        Accuracy, sensitivity, specificity, auc on test set.\n",
        "\n",
        "    \"\"\"\n",
        "    train_x, test_x, train_y, test_y = splits\n",
        "\n",
        "    # Fit the model\n",
        "    ml_model.fit(train_x, train_y)\n",
        "\n",
        "    # Calculate model performance results\n",
        "    accuracy, sens, spec, auc = model_performance(ml_model, test_x, test_y, verbose)\n",
        "\n",
        "    return accuracy, sens, spec, auc"
      ],
      "metadata": {
        "id": "Wk2Io_jBJR_1"
      },
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "fingerprint_to_model = compound_df.fp.tolist()\n",
        "label_to_model = compound_df.active.tolist()\n",
        "\n",
        "# Split data randomly in train and test set\n",
        "# note that we use test/train_x for the respective fingerprint splits\n",
        "# and test/train_y for the respective label splits\n",
        "(\n",
        "    static_train_x,\n",
        "    static_test_x,\n",
        "    static_train_y,\n",
        "    static_test_y,\n",
        ") = train_test_split(fingerprint_to_model, label_to_model, test_size=0.2)\n",
        "splits = [static_train_x, static_test_x, static_train_y, static_test_y]\n",
        "# NBVAL_CHECK_OUTPUT\n",
        "print(\"Training data size:\", len(static_train_x))\n",
        "print(\"Test data size:\", len(static_test_x))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "88Y4pIPqJUSy",
        "outputId": "52e0cbad-fcb2-4e99-9457-8bef8cd1f95f"
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Training data size: 4344\n",
            "Test data size: 1087\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Set model parameter for random forest\n",
        "param = {\n",
        "    \"n_estimators\": 100,  # number of trees to grows\n",
        "    \"criterion\": \"entropy\",  # cost function to be optimized for a split\n",
        "}\n",
        "model_RF = RandomForestClassifier(**param)"
      ],
      "metadata": {
        "id": "uiXhKkQQJbug"
      },
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Fit model on single split\n",
        "performance_measures = model_training_and_validation(model_RF, \"RF\", splits)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "srnZbkH4JeDk",
        "outputId": "267637b5-31f6-449f-932f-67bc58cd2930"
      },
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sensitivity: 0.86\n",
            "Specificity: 0.72\n",
            "AUC: 0.86\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Initialize the list that stores all models. First one is RF.\n",
        "models = [{\"label\": \"Model_RF\", \"model\": model_RF}]\n",
        "# Plot roc curve\n",
        "plot_roc_curves_for_models(models, static_test_x, static_test_y);"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 472
        },
        "id": "zhK5ki29JhVI",
        "outputId": "bfc92be5-e28a-4320-8872-f29ab083686e"
      },
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Specify model: support vector classifier with an RBF kernel.\n",
        "# probability=True enables probability estimates on the fitted model\n",
        "# (presumably needed for the AUC/ROC computation downstream -- TODO confirm).\n",
        "model_SVM = svm.SVC(kernel=\"rbf\", C=1, gamma=0.1, probability=True)\n",
        "\n",
        "# Fit model on single split; the call prints sensitivity, specificity and AUC\n",
        "performance_measures = model_training_and_validation(model_SVM, \"SVM\", splits)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "agNe2fMbJkvR",
        "outputId": "a657a622-4cf3-4f29-c80f-f470689011ac"
      },
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sensitivity: 0.90\n",
            "Specificity: 0.72\n",
            "AUC: 0.87\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Register the SVM model. Any existing \"Model_SVM\" entry is dropped first so that\n",
        "# re-running this cell does not add a duplicate (which would then be plotted and\n",
        "# cross-validated twice).\n",
        "models = [entry for entry in models if entry[\"label\"] != \"Model_SVM\"]\n",
        "models.append({\"label\": \"Model_SVM\", \"model\": model_SVM})\n",
        "# Plot roc curve\n",
        "plot_roc_curves_for_models(models, static_test_x, static_test_y);"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 472
        },
        "id": "nL1eKszjJnVn",
        "outputId": "21626e12-9690-454c-f7a5-84faac6c63d5"
      },
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Specify model: multi-layer perceptron with two hidden layers\n",
        "# (5 neurons in the first, 3 in the second)\n",
        "model_ANN = MLPClassifier(hidden_layer_sizes=(5, 3))\n",
        "\n",
        "# Fit model on single split; the call prints sensitivity, specificity and AUC\n",
        "performance_measures = model_training_and_validation(model_ANN, \"ANN\", splits)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1m3rVarmJpx0",
        "outputId": "3904f620-b442-4e5d-9f16-e05d1d8b99df"
      },
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sensitivity: 0.88\n",
            "Specificity: 0.70\n",
            "AUC: 0.86\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Register the ANN model. Any existing \"Model_ANN\" entry is dropped first so that\n",
        "# re-running this cell does not add a duplicate (which would then be plotted and\n",
        "# cross-validated twice).\n",
        "models = [entry for entry in models if entry[\"label\"] != \"Model_ANN\"]\n",
        "models.append({\"label\": \"Model_ANN\", \"model\": model_ANN})\n",
        "# Plot roc curve\n",
        "# NOTE(review): the trailing `True` is a positional flag of\n",
        "# plot_roc_curves_for_models (presumably \"save the figure\") -- confirm.\n",
        "plot_roc_curves_for_models(models, static_test_x, static_test_y, True);"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 472
        },
        "id": "OKL5ro69Jtoc",
        "outputId": "1f691403-b522-4124-df46-69086615288f"
      },
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def crossvalidation(ml_model, df, n_folds=5, verbose=False, random_state=None):\n",
        "    \"\"\"\n",
        "    Machine learning model training and validation in a cross-validation loop.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    ml_model: sklearn model object\n",
        "        The machine learning model to train. It is not fitted itself;\n",
        "        an unfitted clone is trained for every fold.\n",
        "    df: pd.DataFrame\n",
        "        Data set with one fingerprint per row in column `fp` and the\n",
        "        associated activity label in column `active`.\n",
        "    n_folds: int, optional\n",
        "        Number of folds for cross-validation.\n",
        "    verbose: bool, optional\n",
        "        Forwarded to `model_performance` for every fold\n",
        "        (performance measures are printed).\n",
        "    random_state: int, RandomState instance or None, optional\n",
        "        Forwarded to `KFold` to control how rows are shuffled into folds.\n",
        "        Pass an int to get the same folds on every call; the default\n",
        "        (None) keeps the previous behaviour (NumPy's global random state).\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    tuple of list\n",
        "        `(acc_per_fold, sens_per_fold, spec_per_fold, auc_per_fold)`:\n",
        "        the four values returned by `model_performance`, collected\n",
        "        over the folds (one entry per fold in each list).\n",
        "\n",
        "    \"\"\"\n",
        "    t0 = time.time()\n",
        "    # Shuffle the indices for the k-fold cross-validation\n",
        "    kf = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)\n",
        "\n",
        "    # Results for each of the cross-validation folds\n",
        "    acc_per_fold = []\n",
        "    sens_per_fold = []\n",
        "    spec_per_fold = []\n",
        "    auc_per_fold = []\n",
        "\n",
        "    # Loop over the folds\n",
        "    for train_index, test_index in kf.split(df):\n",
        "        # clone model -- we want a fresh, unfitted copy per fold!\n",
        "        fold_model = clone(ml_model)\n",
        "\n",
        "        # Select the rows of this fold once and reuse them for x and y\n",
        "        train_df = df.iloc[train_index]\n",
        "        test_df = df.iloc[test_index]\n",
        "\n",
        "        # Training: fingerprints and labels are handed over as plain lists\n",
        "        fold_model.fit(train_df[\"fp\"].tolist(), train_df[\"active\"].tolist())\n",
        "\n",
        "        # Testing: performance on the held-out fold.\n",
        "        # The AUC is bound to `fold_auc` so that the `auc` function imported\n",
        "        # from sklearn.metrics is not shadowed inside this function.\n",
        "        accuracy, sens, spec, fold_auc = model_performance(\n",
        "            fold_model, test_df[\"fp\"].tolist(), test_df[\"active\"].tolist(), verbose\n",
        "        )\n",
        "\n",
        "        # Save results\n",
        "        acc_per_fold.append(accuracy)\n",
        "        sens_per_fold.append(sens)\n",
        "        spec_per_fold.append(spec)\n",
        "        auc_per_fold.append(fold_auc)\n",
        "\n",
        "    # Print statistics of results\n",
        "    print(\n",
        "        f\"Mean accuracy: {np.mean(acc_per_fold):.2f} \\t\"\n",
        "        f\"and std : {np.std(acc_per_fold):.2f} \\n\"\n",
        "        f\"Mean sensitivity: {np.mean(sens_per_fold):.2f} \\t\"\n",
        "        f\"and std : {np.std(sens_per_fold):.2f} \\n\"\n",
        "        f\"Mean specificity: {np.mean(spec_per_fold):.2f} \\t\"\n",
        "        f\"and std : {np.std(spec_per_fold):.2f} \\n\"\n",
        "        f\"Mean AUC: {np.mean(auc_per_fold):.2f} \\t\"\n",
        "        f\"and std : {np.std(auc_per_fold):.2f} \\n\"\n",
        "        f\"Time taken : {time.time() - t0:.2f}s\\n\"\n",
        "    )\n",
        "\n",
        "    return acc_per_fold, sens_per_fold, spec_per_fold, auc_per_fold"
      ],
      "metadata": {
        "id": "yxrtPIX2KVPz"
      },
      "execution_count": 25,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Number of folds passed as `n_folds` to the crossvalidation() calls\n",
        "N_FOLDS = 3"
      ],
      "metadata": {
        "id": "0E_3bvsuKdu6"
      },
      "execution_count": 26,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Cross-validate every registered model on the current fingerprints in compound_df\n",
        "for model in models:\n",
        "    print(\"\\n======= \")\n",
        "    print(model[\"label\"])\n",
        "    crossvalidation(model[\"model\"], compound_df, n_folds=N_FOLDS)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "S7enDvcdKfu5",
        "outputId": "ad18ee89-bd3f-4e2c-bb52-07cbc589508c"
      },
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "======= \n",
            "Model_RF\n",
            "Mean accuracy: 0.82 \tand std : 0.01 \n",
            "Mean sensitivity: 0.88 \tand std : 0.01 \n",
            "Mean specificity: 0.72 \tand std : 0.00 \n",
            "Mean AUC: 0.88 \tand std : 0.00 \n",
            "Time taken : 1.56s\n",
            "\n",
            "\n",
            "======= \n",
            "Model_SVM\n",
            "Mean accuracy: 0.83 \tand std : 0.01 \n",
            "Mean sensitivity: 0.91 \tand std : 0.01 \n",
            "Mean specificity: 0.71 \tand std : 0.00 \n",
            "Mean AUC: 0.89 \tand std : 0.01 \n",
            "Time taken : 15.20s\n",
            "\n",
            "\n",
            "======= \n",
            "Model_ANN\n",
            "Mean accuracy: 0.78 \tand std : 0.00 \n",
            "Mean sensitivity: 0.85 \tand std : 0.03 \n",
            "Mean specificity: 0.67 \tand std : 0.05 \n",
            "Mean AUC: 0.83 \tand std : 0.00 \n",
            "Time taken : 5.18s\n",
            "\n",
            "\n",
            "======= \n",
            "Model_ANN\n",
            "Mean accuracy: 0.78 \tand std : 0.00 \n",
            "Mean sensitivity: 0.86 \tand std : 0.02 \n",
            "Mean specificity: 0.68 \tand std : 0.02 \n",
            "Mean AUC: 0.84 \tand std : 0.01 \n",
            "Time taken : 4.72s\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Reset data frame: rebind `compound_df` to a fresh copy of `chembl_df`\n",
        "compound_df = chembl_df.copy()"
      ],
      "metadata": {
        "id": "eUSUe5glKxIp"
      },
      "execution_count": 28,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Encode every SMILES string with the \"morgan3\" method of smiles_to_fp\n",
        "# (Morgan fingerprint with radius 3) and store the result in column `fp`\n",
        "compound_df[\"fp\"] = compound_df[\"smiles\"].apply(\n",
        "    lambda smiles: smiles_to_fp(smiles, \"morgan3\")\n",
        ")\n",
        "compound_df.head(3)\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 344
        },
        "id": "mhU26x6CKyuh",
        "outputId": "cffb681d-42c1-4a39-d2b1-bd85ac260eff"
      },
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "  molecule_chembl_id                               smiles      pIC50  active  \\\n",
              "0        CHEMBL63786    Brc1cccc(Nc2ncnc3cc4ccccc4cc23)c1  11.522879     1.0   \n",
              "1        CHEMBL53711   CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1  11.221849     1.0   \n",
              "2        CHEMBL35820  CCOc1cc2ncnc(Nc3cccc(Br)c3)c2cc1OCC  11.221849     1.0   \n",
              "\n",
              "                                                  fp  \n",
              "0  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  \n",
              "1  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  \n",
              "2  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  "
            ],
            "text/html": [
              "\n",
              "\n",
              "  <div id=\"df-d3f93327-1742-4b26-af2d-c2872008da5d\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>molecule_chembl_id</th>\n",
              "      <th>smiles</th>\n",
              "      <th>pIC50</th>\n",
              "      <th>active</th>\n",
              "      <th>fp</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>CHEMBL63786</td>\n",
              "      <td>Brc1cccc(Nc2ncnc3cc4ccccc4cc23)c1</td>\n",
              "      <td>11.522879</td>\n",
              "      <td>1.0</td>\n",
              "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>CHEMBL53711</td>\n",
              "      <td>CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1</td>\n",
              "      <td>11.221849</td>\n",
              "      <td>1.0</td>\n",
              "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>CHEMBL35820</td>\n",
              "      <td>CCOc1cc2ncnc(Nc3cccc(Br)c3)c2cc1OCC</td>\n",
              "      <td>11.221849</td>\n",
              "      <td>1.0</td>\n",
              "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-d3f93327-1742-4b26-af2d-c2872008da5d')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "\n",
              "\n",
              "\n",
              "    <div id=\"df-0084d01f-898f-4898-a059-1c6bc886ae3c\">\n",
              "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-0084d01f-898f-4898-a059-1c6bc886ae3c')\"\n",
              "              title=\"Suggest charts.\"\n",
              "              style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "      </button>\n",
              "    </div>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "    background-color: #E8F0FE;\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: #1967D2;\n",
              "    height: 32px;\n",
              "    padding: 0 0 0 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: #E2EBFA;\n",
              "    box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: #174EA6;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "    background-color: #3B4455;\n",
              "    fill: #D2E3FC;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart:hover {\n",
              "    background-color: #434B5C;\n",
              "    box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "    filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "    fill: #FFFFFF;\n",
              "  }\n",
              "</style>\n",
              "\n",
              "    <script>\n",
              "      async function quickchart(key) {\n",
              "        const containerElement = document.querySelector('#' + key);\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      }\n",
              "    </script>\n",
              "\n",
              "      <script>\n",
              "\n",
              "function displayQuickchartButton(domScope) {\n",
              "  let quickchartButtonEl =\n",
              "    domScope.querySelector('#df-0084d01f-898f-4898-a059-1c6bc886ae3c button.colab-df-quickchart');\n",
              "  quickchartButtonEl.style.display =\n",
              "    google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "}\n",
              "\n",
              "        displayQuickchartButton(document);\n",
              "      </script>\n",
              "      <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-d3f93327-1742-4b26-af2d-c2872008da5d button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-d3f93327-1742-4b26-af2d-c2872008da5d');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n"
            ]
          },
          "metadata": {},
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "for model in models:\n",
        "    # SVM is super slow with long fingerprints and will have a performance\n",
        "    # similar to RF, so it is left out of this test. If you want to run it\n",
        "    # anyway, remove the condition below.\n",
        "    if model[\"label\"] != \"Model_SVM\":\n",
        "        print(\"\\n=======\")\n",
        "        print(model[\"label\"])\n",
        "        crossvalidation(model[\"model\"], compound_df, n_folds=N_FOLDS)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "G4khVL1KK1v4",
        "outputId": "71e60815-32d8-4966-cd3f-9c28e8dcdcea"
      },
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "=======\n",
            "Model_RF\n",
            "Mean accuracy: 0.85 \tand std : 0.00 \n",
            "Mean sensitivity: 0.90 \tand std : 0.00 \n",
            "Mean specificity: 0.78 \tand std : 0.01 \n",
            "Mean AUC: 0.91 \tand std : 0.00 \n",
            "Time taken : 7.14s\n",
            "\n",
            "\n",
            "=======\n",
            "Model_ANN\n",
            "Mean accuracy: 0.80 \tand std : 0.02 \n",
            "Mean sensitivity: 0.85 \tand std : 0.02 \n",
            "Mean specificity: 0.74 \tand std : 0.04 \n",
            "Mean AUC: 0.85 \tand std : 0.01 \n",
            "Time taken : 25.79s\n",
            "\n",
            "\n",
            "=======\n",
            "Model_ANN\n",
            "Mean accuracy: 0.82 \tand std : 0.01 \n",
            "Mean sensitivity: 0.87 \tand std : 0.00 \n",
            "Mean specificity: 0.75 \tand std : 0.02 \n",
            "Mean AUC: 0.88 \tand std : 0.01 \n",
            "Time taken : 25.99s\n",
            "\n"
          ]
        }
      ]
    }
  ]
}