[480cb7]: / ML_4.ipynb

Download this file

951 lines (951 with data), 29.9 kB

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyO/Pafq/n0HsCJL97lqmyPJ",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/francescopatane96/Computer_aided_drug_discovery_kit/blob/main/ML_4.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Virtual screening"
      ],
      "metadata": {
        "id": "IwnuqmRtAmlY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install rdkit"
      ],
      "metadata": {
        "id": "wS_ngR5dEm9p"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PMKUrKwU_N6_"
      },
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "from rdkit import Chem, DataStructs\n",
        "from rdkit.Chem import (\n",
        "    PandasTools,\n",
        "    Draw,\n",
        "    Descriptors,\n",
        "    MACCSkeys,\n",
        "    rdFingerprintGenerator,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Molecules in SMILES format\n",
        "molecule_smiles = [\n",
        "    \"CC1C2C(C3C(C(=O)C(=C(C3(C(=O)C2=C(C4=C1C=CC=C4O)O)O)O)C(=O)N)N(C)C)O\",\n",
        "    \"CC1(C(N2C(S1)C(C2=O)NC(=O)C(C3=CC=C(C=C3)O)N)C(=O)O)C\",\n",
        "    \"C1=COC(=C1)CNC2=CC(=C(C=C2C(=O)O)S(=O)(=O)N)Cl\",\n",
        "    \"CCCCCCCCCCCC(=O)OCCOC(=O)CCCCCCCCCCC\",\n",
        "    \"C1NC2=CC(=C(C=C2S(=O)(=O)N1)S(=O)(=O)N)Cl\",\n",
        "    \"CC1=C(C(CCC1)(C)C)C=CC(=CC=CC(=CC(=O)O)C)C\",\n",
        "    \"CC1(C2CC3C(C(=O)C(=C(C3(C(=O)C2=C(C4=C1C=CC=C4O)O)O)O)C(=O)N)N(C)C)O\",\n",
        "    \"CC1C(CC(=O)C2=C1C=CC=C2O)C(=O)O\",\n",
        "]\n",
        "\n",
        "# List of molecule names\n",
        "molecule_names = [\n",
        "    \"Doxycycline\",\n",
        "    \"Amoxicilline\",\n",
        "    \"Furosemide\",\n",
        "    \"Glycol dilaurate\",\n",
        "    \"Hydrochlorothiazide\",\n",
        "    \"Isotretinoin\",\n",
        "    \"Tetracycline\",\n",
        "    \"Hemi-cycline D\",\n",
        "]"
      ],
      "metadata": {
        "id": "9M3RgOqhEvfE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "molecules = pd.DataFrame({\"smiles\": molecule_smiles, \"name\": molecule_names})\n",
        "PandasTools.AddMoleculeColumnToFrame(molecules, smilesCol=\"smiles\")\n",
        "# Show first 2 molecules\n",
        "molecules.head(2)"
      ],
      "metadata": {
        "id": "iJBPaUowEx7y"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "Draw.MolsToGridImage(\n",
        "    molecules[\"ROMol\"].to_list(),\n",
        "    molsPerRow=3,\n",
        "    subImgSize=(450, 150),\n",
        "    legends=molecules[\"name\"].to_list(),\n",
        ")"
      ],
      "metadata": {
        "id": "NqrhlCJzE3no"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Note -- we use pandas apply function to apply the MolWt function\n",
        "# to all ROMol objects in the DataFrame\n",
        "molecules[\"molecule_weight\"] = molecules.ROMol.apply(Descriptors.MolWt)\n",
        "# Sort molecules by molecular weight\n",
        "molecules.sort_values([\"molecule_weight\"], ascending=False, inplace=True)"
      ],
      "metadata": {
        "id": "L3BUmluyFA24"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Show only molecule names and molecular weights\n",
        "molecules[[\"smiles\", \"name\", \"molecule_weight\"]]"
      ],
      "metadata": {
        "id": "4ndk80ByFEHw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "Draw.MolsToGridImage(\n",
        "    molecules[\"ROMol\"],\n",
        "    legends=[\n",
        "        f\"{molecule['name']}: {molecule['molecule_weight']:.2f} Da\"\n",
        "        for index, molecule in molecules.iterrows()\n",
        "    ],\n",
        "    subImgSize=(450, 150),\n",
        "    molsPerRow=3,\n",
        ")"
      ],
      "metadata": {
        "id": "QMDormacFJkh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "molecule = molecules[\"ROMol\"][0]\n",
        "molecule"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EdmfQDZKFPmx",
        "outputId": "8d6f0066-28da-47df-c51c-790e158ceaf0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<rdkit.Chem.rdchem.Mol at 0x7f7ad3a09670>"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "maccs_fp = MACCSkeys.GenMACCSKeys(molecule)"
      ],
      "metadata": {
        "id": "oYpCd1wnFV4p"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "molecules[\"maccs\"] = molecules.ROMol.apply(MACCSkeys.GenMACCSKeys)"
      ],
      "metadata": {
        "id": "Akm0V6NrFaji"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "circular_int_fp = rdFingerprintGenerator.GetCountFPs([molecule])[0]\n",
        "circular_int_fp"
      ],
      "metadata": {
        "id": "XhIpW23cFgoR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(f\"Print non-zero elements:\\n{circular_int_fp.GetNonzeroElements()}\")"
      ],
      "metadata": {
        "id": "PUdhmhz9FkSB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Note that the function takes a list as input parameter\n",
        "# (even if we only want to pass one molecule)\n",
        "circular_bit_fp = rdFingerprintGenerator.GetFPs([molecule])[0]\n",
        "circular_bit_fp"
      ],
      "metadata": {
        "id": "5wY0TR5DFocC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(f\"Print top 400 fingerprint bits:\\n{circular_bit_fp.ToBitString()[:400]}\")"
      ],
      "metadata": {
        "id": "Kcc-7JeIFspR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "molecules[\"morgan\"] = rdFingerprintGenerator.GetFPs(molecules[\"ROMol\"].tolist())"
      ],
      "metadata": {
        "id": "wdt0ZnleFx5q"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Example molecules\n",
        "molecule1 = molecules[\"ROMol\"][0]\n",
        "molecule2 = molecules[\"ROMol\"][1]\n",
        "\n",
        "# Example fingerprints\n",
        "maccs_fp1 = MACCSkeys.GenMACCSKeys(molecule1)\n",
        "maccs_fp2 = MACCSkeys.GenMACCSKeys(molecule2)"
      ],
      "metadata": {
        "id": "-QcPiu3QF3Eh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "DataStructs.TanimotoSimilarity(maccs_fp1, maccs_fp2)"
      ],
      "metadata": {
        "id": "Y_DmsFphF5vj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Define molecule query and list\n",
        "molecule_query = molecules[\"maccs\"][0]\n",
        "molecule_list = molecules[\"maccs\"].to_list()\n",
        "# Calculate similarty values between query and list elements\n",
        "molecules[\"tanimoto_maccs\"] = DataStructs.BulkTanimotoSimilarity(molecule_query, molecule_list)\n",
        "molecules[\"dice_maccs\"] = DataStructs.BulkDiceSimilarity(molecule_query, molecule_list)"
      ],
      "metadata": {
        "id": "KXZq_0ghGDyK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "preview = molecules.sort_values([\"tanimoto_maccs\"], ascending=False).reset_index()\n",
        "preview[[\"name\", \"tanimoto_maccs\", \"dice_maccs\"]]"
      ],
      "metadata": {
        "id": "oLBgsIQDGIrB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def draw_ranked_molecules(molecules, sort_by_column):\n",
        "    \"\"\"\n",
        "    Draw molecules sorted by a given column.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    molecules : pandas.DataFrame\n",
        "        Molecules (with \"ROMol\" and \"name\" columns and a column to sort by.\n",
        "    sort_by_column : str\n",
        "        Name of the column used to sort the molecules by.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    Draw.MolsToGridImage\n",
        "        2D visualization of sorted molecules.\n",
        "    \"\"\"\n",
        "\n",
        "    molecules_sorted = molecules.sort_values([sort_by_column], ascending=False).reset_index()\n",
        "    return Draw.MolsToGridImage(\n",
        "        molecules_sorted[\"ROMol\"],\n",
        "        legends=[\n",
        "            f\"#{index+1} {molecule['name']}, similarity={molecule[sort_by_column]:.2f}\"\n",
        "            for index, molecule in molecules_sorted.iterrows()\n",
        "        ],\n",
        "        molsPerRow=3,\n",
        "        subImgSize=(450, 150),\n",
        "    )"
      ],
      "metadata": {
        "id": "3XGUf57tGRHI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "draw_ranked_molecules(molecules, \"tanimoto_maccs\")"
      ],
      "metadata": {
        "id": "1GaWEV_xGV8I"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Define molecule query and list\n",
        "molecule_query = molecules[\"morgan\"][0]\n",
        "molecule_list = molecules[\"morgan\"].to_list()\n",
        "# Calculate similarty values between query and list elements\n",
        "molecules[\"tanimoto_morgan\"] = DataStructs.BulkTanimotoSimilarity(molecule_query, molecule_list)\n",
        "molecules[\"dice_morgan\"] = DataStructs.BulkDiceSimilarity(molecule_query, molecule_list)"
      ],
      "metadata": {
        "id": "IODdVQqKGgMB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "preview = molecules.sort_values([\"tanimoto_morgan\"], ascending=False).reset_index()\n",
        "preview[[\"name\", \"tanimoto_morgan\", \"dice_morgan\", \"tanimoto_maccs\", \"dice_maccs\"]]\n"
      ],
      "metadata": {
        "id": "fivuZHjlGipq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "draw_ranked_molecules(molecules, \"tanimoto_morgan\")"
      ],
      "metadata": {
        "id": "UHv5-rSKGqbt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "fig, ax = plt.subplots(figsize=(6, 6))\n",
        "molecules.plot(\"tanimoto_maccs\", \"tanimoto_morgan\", kind=\"scatter\", ax=ax)\n",
        "ax.plot([0, 1], [0, 1], \"k--\")\n",
        "ax.set_xlabel(\"Tanimoto (MACCS)\")\n",
        "ax.set_ylabel(\"Tanimoto (Morgan)\")\n",
        "fig;"
      ],
      "metadata": {
        "id": "3o5w6sDyGwYN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "molecule_dataset = pd.read_csv(\n",
        "    \"TNFB_compounds_lipinski.csv\",\n",
        "    usecols=[\"molecule_chembl_id\", \"smiles\", \"pIC50\"],\n",
        ")\n",
        "print(f\"Number of molecules in dataset: {len(molecule_dataset)}\")\n",
        "molecule_dataset.head(5)"
      ],
      "metadata": {
        "id": "wH-Q8hyv9XZn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "query = Chem.MolFromSmiles(\"O=C(Nc1cc(C(F)(F)F)cc(C(F)(F)F)c1)c1cnc(Cl)nc1C(F)(F)F\")\n",
        "query"
      ],
      "metadata": {
        "id": "3WmxmCzj97WQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "maccs_fp_query = MACCSkeys.GenMACCSKeys(query)\n",
        "circular_fp_query = rdFingerprintGenerator.GetCountFPs([query])[0]"
      ],
      "metadata": {
        "id": "GkcDaxez-OuP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "PandasTools.AddMoleculeColumnToFrame(molecule_dataset, \"smiles\")\n",
        "circular_fp_list = rdFingerprintGenerator.GetCountFPs(molecule_dataset[\"ROMol\"].tolist())\n",
        "maccs_fp_list = molecule_dataset[\"ROMol\"].apply(MACCSkeys.GenMACCSKeys).tolist()"
      ],
      "metadata": {
        "id": "n2bzjaPD-SfH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "molecule_dataset[\"tanimoto_maccs\"] = DataStructs.BulkTanimotoSimilarity(\n",
        "    maccs_fp_query, maccs_fp_list\n",
        ")\n",
        "molecule_dataset[\"tanimoto_morgan\"] = DataStructs.BulkTanimotoSimilarity(\n",
        "    circular_fp_query, circular_fp_list\n",
        ")"
      ],
      "metadata": {
        "id": "fxxsJrac-WAw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "molecule_dataset[\"dice_maccs\"] = DataStructs.BulkDiceSimilarity(maccs_fp_query, maccs_fp_list)\n",
        "molecule_dataset[\"dice_morgan\"] = DataStructs.BulkDiceSimilarity(\n",
        "    circular_fp_query, circular_fp_list\n",
        ")"
      ],
      "metadata": {
        "id": "PQT4ako_-YoL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# NBVAL_CHECK_OUTPUT\n",
        "molecule_dataset[\n",
        "    [\"smiles\", \"tanimoto_maccs\", \"tanimoto_morgan\", \"dice_maccs\", \"dice_morgan\"]\n",
        "].head(5)"
      ],
      "metadata": {
        "id": "6ESskO7s-aVb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Show all columns\n",
        "molecule_dataset.head(3)"
      ],
      "metadata": {
        "id": "4lvx2Tx--fkQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "fig, axes = plt.subplots(figsize=(10, 6), nrows=2, ncols=2)\n",
        "molecule_dataset.hist([\"tanimoto_maccs\"], ax=axes[0, 0])\n",
        "molecule_dataset.hist([\"tanimoto_morgan\"], ax=axes[0, 1])\n",
        "molecule_dataset.hist([\"dice_maccs\"], ax=axes[1, 0])\n",
        "molecule_dataset.hist([\"dice_morgan\"], ax=axes[1, 1])\n",
        "axes[1, 0].set_xlabel(\"similarity value\")\n",
        "axes[1, 0].set_ylabel(\"# molecules\")\n",
        "fig;"
      ],
      "metadata": {
        "id": "7mySir-w-kah"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "fig, axes = plt.subplots(figsize=(12, 6), nrows=1, ncols=2)\n",
        "\n",
        "molecule_dataset.plot(\"tanimoto_maccs\", \"dice_maccs\", kind=\"scatter\", ax=axes[0])\n",
        "axes[0].plot([0, 1], [0, 1], \"k--\")\n",
        "axes[0].set_xlabel(\"Tanimoto (MACCS)\")\n",
        "axes[0].set_ylabel(\"Dice (MACCS)\")\n",
        "\n",
        "molecule_dataset.plot(\"tanimoto_morgan\", \"dice_morgan\", kind=\"scatter\", ax=axes[1])\n",
        "axes[1].plot([0, 1], [0, 1], \"k--\")\n",
        "axes[1].set_xlabel(\"Tanimoto (Morgan)\")\n",
        "axes[1].set_ylabel(\"Dice (Morgan)\")\n",
        "\n",
        "fig;"
      ],
      "metadata": {
        "id": "FMB6Qty1-pbn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "molecule_dataset.sort_values([\"tanimoto_morgan\"], ascending=False).head(3)"
      ],
      "metadata": {
        "id": "4UzFAWc6-ykA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "top_n_molecules = 10\n",
        "top_molecules = molecule_dataset.sort_values([\"tanimoto_morgan\"], ascending=False).reset_index()\n",
        "top_molecules = top_molecules[:top_n_molecules]\n",
        "legends = [\n",
        "    f\"#{index+1} {molecule['molecule_chembl_id']}, pIC50={molecule['pIC50']:.2f}\"\n",
        "    for index, molecule in top_molecules.iterrows()\n",
        "]\n",
        "Chem.Draw.MolsToGridImage(\n",
        "    mols=[query] + top_molecules[\"ROMol\"].tolist(),\n",
        "    legends=([\"Gefitinib\"] + legends),\n",
        "    molsPerRow=3,\n",
        "    subImgSize=(450, 150),\n",
        ")"
      ],
      "metadata": {
        "id": "BgcxSuAj-6gv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "molecule_dataset.head(3)"
      ],
      "metadata": {
        "id": "u0tFvI4Y_Msf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_enrichment_data(molecules, similarity_measure, pic50_cutoff):\n",
        "    \"\"\"\n",
        "    Calculates x and y values for enrichment plot:\n",
        "        x - % ranked dataset\n",
        "        y - % true actives identified\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    molecules : pandas.DataFrame\n",
        "        Molecules with similarity values to a query molecule.\n",
        "    similarity_measure : str\n",
        "        Column name which will be used to sort the DataFrame.\n",
        "    pic50_cutoff : float\n",
        "        pIC50 cutoff value used to discriminate active and inactive molecules.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    pandas.DataFrame\n",
        "        Enrichment data: Percentage of ranked dataset by similarity vs. percentage of identified true actives.\n",
        "    \"\"\"\n",
        "\n",
        "    # Get number of molecules in data set\n",
        "    molecules_all = len(molecules)\n",
        "\n",
        "    # Get number of active molecules in data set\n",
        "    actives_all = sum(molecules[\"pIC50\"] >= pic50_cutoff)\n",
        "\n",
        "    # Initialize a list that will hold the counter for actives and molecules while iterating through our dataset\n",
        "    actives_counter_list = []\n",
        "\n",
        "    # Initialize counter for actives\n",
        "    actives_counter = 0\n",
        "\n",
        "    # Note: Data must be ranked for enrichment plots:\n",
        "    # Sort molecules by selected similarity measure\n",
        "    molecules.sort_values([similarity_measure], ascending=False, inplace=True)\n",
        "\n",
        "    # Iterate over the ranked dataset and check each molecule if active (by checking bioactivity)\n",
        "    for value in molecules[\"pIC50\"]:\n",
        "        if value >= pic50_cutoff:\n",
        "            actives_counter += 1\n",
        "        actives_counter_list.append(actives_counter)\n",
        "\n",
        "    # Transform number of molecules into % ranked dataset\n",
        "    molecules_percentage_list = [i / molecules_all for i in range(1, molecules_all + 1)]\n",
        "\n",
        "    # Transform number of actives into % true actives identified\n",
        "    actives_percentage_list = [i / actives_all for i in actives_counter_list]\n",
        "\n",
        "    # Generate DataFrame with x and y values as well as label\n",
        "    enrichment = pd.DataFrame(\n",
        "        {\n",
        "            \"% ranked dataset\": molecules_percentage_list,\n",
        "            \"% true actives identified\": actives_percentage_list,\n",
        "        }\n",
        "    )\n",
        "\n",
        "    return enrichment"
      ],
      "metadata": {
        "id": "GupUIRiM_Pin"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pic50_cutoff = 6.3"
      ],
      "metadata": {
        "id": "KdgWp1l8_T4L"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "similarity_measures = [\"tanimoto_maccs\", \"tanimoto_morgan\"]\n",
        "enrichment_data = {\n",
        "    similarity_measure: get_enrichment_data(molecule_dataset, similarity_measure, pic50_cutoff)\n",
        "    for similarity_measure in similarity_measures\n",
        "}"
      ],
      "metadata": {
        "id": "mEKlMs7q_XG_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# NBVAL_CHECK_OUTPUT\n",
        "enrichment_data[\"tanimoto_maccs\"].head()"
      ],
      "metadata": {
        "id": "58Jqhzk__ZHX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "fig, ax = plt.subplots(figsize=(6, 6))\n",
        "\n",
        "fontsize = 20\n",
        "\n",
        "# Plot enrichment data\n",
        "for similarity_measure, enrichment in enrichment_data.items():\n",
        "    ax = enrichment.plot(\n",
        "        ax=ax,\n",
        "        x=\"% ranked dataset\",\n",
        "        y=\"% true actives identified\",\n",
        "        label=similarity_measure,\n",
        "        alpha=0.5,\n",
        "        linewidth=4,\n",
        "    )\n",
        "ax.set_ylabel(\"% True actives identified\", size=fontsize)\n",
        "ax.set_xlabel(\"% Ranked dataset\", size=fontsize)\n",
        "\n",
        "# Plot optimal curve: Ratio of actives in dataset\n",
        "ratio_actives = sum(molecule_dataset[\"pIC50\"] >= pic50_cutoff) / len(molecule_dataset)\n",
        "ax.plot(\n",
        "    [0, ratio_actives, 1],\n",
        "    [0, 1, 1],\n",
        "    label=\"Optimal curve\",\n",
        "    color=\"black\",\n",
        "    linestyle=\"--\",\n",
        ")\n",
        "\n",
        "# Plot random curve\n",
        "ax.plot([0, 1], [0, 1], label=\"Random curve\", color=\"grey\", linestyle=\"--\")\n",
        "\n",
        "plt.tick_params(labelsize=16)\n",
        "plt.legend(\n",
        "    labels=[\"MACCS\", \"Morgan\", \"Optimal\", \"Random\"],\n",
        "    loc=(0.5, 0.08),\n",
        "    fontsize=fontsize,\n",
        "    labelspacing=0.3,\n",
        ")\n",
        "\n",
        "# Save plot -- use bbox_inches to include text boxes\n",
        "plt.savefig(\n",
        "    \"enrichment_plot.png\",\n",
        "    dpi=300,\n",
        "    bbox_inches=\"tight\",\n",
        "    transparent=True,\n",
        ")\n",
        "\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "8is59ZNb_ctY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def calculate_enrichment_factor(enrichment, ranked_dataset_percentage_cutoff):\n",
        "    \"\"\"\n",
        "    Get the experimental enrichment factor for a given percentage of the ranked dataset.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    enrichment : pd.DataFrame\n",
        "        Enrichment data: Percentage of ranked dataset by similarity vs. percentage of\n",
        "        identified true actives.\n",
        "    ranked_dataset_percentage_cutoff : float or int\n",
        "        Percentage of ranked dataset to be included in enrichment factor calculation.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    float\n",
        "        Experimental enrichment factor.\n",
        "    \"\"\"\n",
        "\n",
        "    # Keep only molecules that meet the cutoff\n",
        "    enrichment = enrichment[\n",
        "        enrichment[\"% ranked dataset\"] <= ranked_dataset_percentage_cutoff / 100\n",
        "    ]\n",
        "    # Get highest percentage of actives and the corresponding percentage of actives\n",
        "    highest_enrichment = enrichment.iloc[-1]\n",
        "    enrichment_factor = round(100 * float(highest_enrichment[\"% true actives identified\"]), 1)\n",
        "    return enrichment_factor"
      ],
      "metadata": {
        "id": "XUWbwjU7_uBR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def calculate_enrichment_factor_random(ranked_dataset_percentage_cutoff):\n",
        "    \"\"\"\n",
        "    Get the random enrichment factor for a given percentage of the ranked dataset.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    ranked_dataset_percentage_cutoff : float or int\n",
        "        Percentage of ranked dataset to be included in enrichment factor calculation.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    float\n",
        "        Random enrichment factor.\n",
        "    \"\"\"\n",
        "\n",
        "    enrichment_factor_random = round(float(ranked_dataset_percentage_cutoff), 1)\n",
        "    return enrichment_factor_random"
      ],
      "metadata": {
        "id": "wbjqjFRr_wqT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def calculate_enrichment_factor_optimal(molecules, ranked_dataset_percentage_cutoff, pic50_cutoff):\n",
        "    \"\"\"\n",
        "    Get the optimal random enrichment factor for a given percentage of the ranked dataset.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    molecules : pandas.DataFrame\n",
        "        the DataFrame with all the molecules and pIC50.\n",
        "    ranked_dataset_percentage_cutoff : float or int\n",
        "        Percentage of ranked dataset to be included in enrichment factor calculation.\n",
        "    activity_cutoff: float\n",
        "        pIC50 cutoff value used to discriminate active and inactive molecules\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    float\n",
        "        Optimal enrichment factor.\n",
        "    \"\"\"\n",
        "\n",
        "    ratio = sum(molecules[\"pIC50\"] >= pic50_cutoff) / len(molecules) * 100\n",
        "    if ranked_dataset_percentage_cutoff <= ratio:\n",
        "        enrichment_factor_optimal = round(100 / ratio * ranked_dataset_percentage_cutoff, 1)\n",
        "    else:\n",
        "        enrichment_factor_optimal = 100.0\n",
        "    return enrichment_factor_optimal"
      ],
      "metadata": {
        "id": "2m4I2IHo_0pj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "ranked_dataset_percentage_cutoff = 5"
      ],
      "metadata": {
        "id": "WbAfjid7_4MA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "for similarity_measure, enrichment in enrichment_data.items():\n",
        "    enrichment_factor = calculate_enrichment_factor(enrichment, ranked_dataset_percentage_cutoff)\n",
        "    print(\n",
        "        f\"Experimental EF for {ranked_dataset_percentage_cutoff}% of ranked dataset ({similarity_measure}): {enrichment_factor}%\"\n",
        "    )"
      ],
      "metadata": {
        "id": "osJhc5pK_5zv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "enrichment_factor_random = calculate_enrichment_factor_random(ranked_dataset_percentage_cutoff)\n",
        "print(\n",
        "    f\"Random EF for {ranked_dataset_percentage_cutoff}% of ranked dataset: {enrichment_factor_random}%\"\n",
        ")\n",
        "enrichment_factor_optimal = calculate_enrichment_factor_optimal(\n",
        "    molecule_dataset, ranked_dataset_percentage_cutoff, pic50_cutoff\n",
        ")\n",
        "print(\n",
        "    f\"Optimal EF for {ranked_dataset_percentage_cutoff}% of ranked dataset: {enrichment_factor_optimal}%\"\n",
        ")\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "id": "dY2atRHUACoz"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}