[480cb7]: / ML_5.ipynb

Download this file

604 lines (604 with data), 20.9 kB

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyOHZYm2RG/vTEZ6jZiHZihY",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/francescopatane96/Computer_aided_drug_discovery_kit/blob/main/ML_5.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%pip install rdkit"
      ],
      "metadata": {
        "id": "2D4euSqkBVDd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%pip install git+https://github.com/volkamerlab/teachopencadd.git"
      ],
      "metadata": {
        "id": "myQilAQCBeKl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "import random\n",
        "from pathlib import Path\n",
        "\n",
        "import pandas as pd\n",
        "import numpy\n",
        "import matplotlib.pyplot as plt\n",
        "from rdkit import Chem\n",
        "from rdkit import DataStructs\n",
        "from rdkit.ML.Cluster import Butina\n",
        "from rdkit.Chem import Draw\n",
        "from rdkit.Chem import rdFingerprintGenerator\n",
        "\n",
        "from teachopencadd.utils import seed_everything\n",
        "\n",
        "seed_everything()  # fix seed to get deterministic outputs"
      ],
      "metadata": {
        "id": "GyKKqPxiBTs2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Read the Lipinski-filtered compound table (data taken from Talktorial T002),\n",
        "# using the first CSV column as index, then preview its size and first rows\n",
        "compound_df = pd.read_csv(\"TNFB_compounds_lipinski.csv\", index_col=0)\n",
        "print(\"Dataframe shape:\", compound_df.shape)\n",
        "compound_df.head()"
      ],
      "metadata": {
        "id": "7hQCDG1LC42R"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Parse every SMILES string into an RDKit molecule and pair it with its ChEMBL ID\n",
        "# index=False drops the dataframe index, so each row unpacks to exactly two values\n",
        "id_smiles_rows = compound_df[[\"molecule_chembl_id\", \"smiles\"]].itertuples(index=False)\n",
        "# NOTE(review): Chem.MolFromSmiles returns None for SMILES it cannot parse; such entries are kept as-is\n",
        "compounds = [(Chem.MolFromSmiles(smiles), chembl_id) for chembl_id, smiles in id_smiles_rows]\n",
        "compounds[:5]"
      ],
      "metadata": {
        "id": "tHvZx_QmDBUu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Build one RDKit fingerprint (generator configured with maxPath=5) per molecule\n",
        "rdkit_gen = rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=5)\n",
        "fingerprints = []\n",
        "for mol, _chembl_id in compounds:\n",
        "    fingerprints.append(rdkit_gen.GetFingerprint(mol))\n",
        "\n",
        "# Report how many fingerprints were generated and the length of the first one\n",
        "print(\"Number of compounds converted:\", len(fingerprints))\n",
        "print(\"Fingerprint length per compound:\", len(fingerprints[0]))\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "id": "5y-nTbl3DDc-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def tanimoto_distance_matrix(fp_list):\n",
        "    \"\"\"Return the pairwise Tanimoto distances of `fp_list` as a flat list.\n",
        "\n",
        "    The result is the lower triangle of the distance matrix, row by row:\n",
        "    d(1,0), d(2,0), d(2,1), d(3,0), ... with distance = 1 - similarity.\n",
        "    \"\"\"\n",
        "    distances = []\n",
        "    # Start at index 1: the first fingerprint has no predecessors to compare with,\n",
        "    # and no fingerprint is ever compared against itself\n",
        "    for position in range(1, len(fp_list)):\n",
        "        # Similarity of this fingerprint to every earlier one in the list\n",
        "        earlier_similarities = DataStructs.BulkTanimotoSimilarity(\n",
        "            fp_list[position], fp_list[:position]\n",
        "        )\n",
        "        # Turn the similarities into distances\n",
        "        distances.extend(1 - value for value in earlier_similarities)\n",
        "    return distances"
      ],
      "metadata": {
        "id": "d9IDVZRADIRu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Example: similarity (and the derived distance) between the first two fingerprints\n",
        "# NBVAL_CHECK_OUTPUT\n",
        "first_fp, second_fp = fingerprints[0], fingerprints[1]\n",
        "sim = DataStructs.TanimotoSimilarity(first_fp, second_fp)\n",
        "print(f\"Tanimoto similarity: {sim:.2f}, distance: {1-sim:.2f}\")"
      ],
      "metadata": {
        "id": "QauDa9HyDKVO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Example: first five entries of the flattened distance matrix (distance = 1 - similarity)\n",
        "tanimoto_distance_matrix(fingerprints)[:5]"
      ],
      "metadata": {
        "id": "yY0K_IF7DMKu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Side note: the result above is a flat list rather than a 2D matrix,\n",
        "# because it stores only one triangle of the distance matrix\n",
        "n = len(fingerprints)\n",
        "\n",
        "# A triangular matrix without the diagonal has n*(n-1)/2 elements\n",
        "elem_triangular_matr = n * (n - 1) / 2\n",
        "print(\n",
        "    f\"Elements in the triangular matrix ({elem_triangular_matr:.0f}) ==\",\n",
        "    f\"tanimoto_distance_matrix(fingerprints) ({len(tanimoto_distance_matrix(fingerprints))})\",\n",
        ")\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "id": "iccQP_zbDOX-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def cluster_fingerprints(fingerprints, cutoff=0.2):\n",
        "    \"\"\"Group fingerprints with the Butina algorithm.\n",
        "\n",
        "    Parameters:\n",
        "        fingerprints: list of fingerprints to cluster\n",
        "        cutoff: distance threshold passed to Butina.ClusterData\n",
        "\n",
        "    Returns:\n",
        "        list of clusters (each a sequence of indices into `fingerprints`),\n",
        "        ordered from the largest cluster to the smallest\n",
        "    \"\"\"\n",
        "    n_fingerprints = len(fingerprints)\n",
        "    # Flattened triangular distance matrix, handed over with isDistData=True\n",
        "    flat_distances = tanimoto_distance_matrix(fingerprints)\n",
        "    raw_clusters = Butina.ClusterData(flat_distances, n_fingerprints, cutoff, isDistData=True)\n",
        "    # Largest clusters first\n",
        "    return sorted(raw_clusters, key=len, reverse=True)"
      ],
      "metadata": {
        "id": "5ZcQJY7MDRxW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Cluster the whole dataset at a distance cut-off of 0.3\n",
        "clusters = cluster_fingerprints(fingerprints, cutoff=0.3)\n",
        "\n",
        "# Summarise how many clusters fall into each size class\n",
        "cluster_sizes = [len(c) for c in clusters]\n",
        "num_clust_g1 = sum(size == 1 for size in cluster_sizes)\n",
        "num_clust_g5 = sum(size > 5 for size in cluster_sizes)\n",
        "num_clust_g25 = sum(size > 25 for size in cluster_sizes)\n",
        "num_clust_g100 = sum(size > 100 for size in cluster_sizes)\n",
        "\n",
        "print(\"total # clusters: \", len(clusters))\n",
        "print(\"# clusters with only 1 compound: \", num_clust_g1)\n",
        "print(\"# clusters with >5 compounds: \", num_clust_g5)\n",
        "print(\"# clusters with >25 compounds: \", num_clust_g25)\n",
        "print(\"# clusters with >100 compounds: \", num_clust_g100)\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "id": "H1XWfxqKDTxN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Bar chart of the cluster sizes, in the order the clusters are stored (largest first)\n",
        "fig, ax = plt.subplots(figsize=(15, 4))\n",
        "sizes_per_cluster = [len(c) for c in clusters]\n",
        "ax.bar(range(1, len(sizes_per_cluster) + 1), sizes_per_cluster, lw=5)\n",
        "ax.set_xlabel(\"Cluster index\")\n",
        "ax.set_ylabel(\"Number of molecules\");"
      ],
      "metadata": {
        "id": "lGxKiS9sDWxB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Compare the cluster size distribution across several distance cut-offs\n",
        "for cutoff in numpy.arange(0.0, 1.0, 0.2):\n",
        "    clusters = cluster_fingerprints(fingerprints, cutoff=cutoff)\n",
        "    fig, ax = plt.subplots(figsize=(15, 4))\n",
        "    ax.set_title(f\"Threshold: {cutoff:3.1f}\")\n",
        "    ax.set_xlabel(\"Cluster index\")\n",
        "    ax.set_ylabel(\"Number of molecules\")\n",
        "    ax.bar(range(1, len(clusters) + 1), [len(c) for c in clusters], lw=5)\n",
        "    display(fig)\n",
        "    # Close the figure once shown: otherwise every figure stays registered in pyplot\n",
        "    # (memory) and the inline backend renders it a second time at the end of the cell\n",
        "    plt.close(fig)"
      ],
      "metadata": {
        "id": "edbn5RRCDcQ-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Final clustering at the chosen distance cut-off\n",
        "cutoff = 0.2\n",
        "clusters = cluster_fingerprints(fingerprints, cutoff=cutoff)\n",
        "\n",
        "# Bar chart of the cluster sizes, also written to a PNG file\n",
        "fig, ax = plt.subplots(figsize=(15, 4))\n",
        "ax.bar(range(1, len(clusters) + 1), [len(c) for c in clusters])\n",
        "ax.set_title(f\"Threshold: {cutoff:3.1f}\")\n",
        "ax.set_xlabel(\"Cluster index\")\n",
        "ax.set_ylabel(\"# molecules\")\n",
        "fig.savefig(f\"cluster_dist_cutoff_{cutoff:4.2f}.png\", dpi=300, bbox_inches=\"tight\", transparent=True)\n",
        "\n",
        "# Summary numbers for this clustering\n",
        "print(\n",
        "    f\"Number of clusters: {len(clusters)} from {len(compounds)} molecules at distance cut-off {cutoff:.2f}\"\n",
        ")\n",
        "print(\"Number of molecules in largest cluster:\", len(clusters[0]))\n",
        "# The two comparisons below use fixed members: the first two entries of the largest\n",
        "# cluster, and the first entries of the largest and second largest cluster\n",
        "print(\n",
        "    f\"Similarity between two random points in same cluster: {DataStructs.TanimotoSimilarity(fingerprints[clusters[0][0]], fingerprints[clusters[0][1]]):.2f}\"\n",
        ")\n",
        "print(\n",
        "    f\"Similarity between two random points in different cluster: {DataStructs.TanimotoSimilarity(fingerprints[clusters[0][0]], fingerprints[clusters[1][0]]):.2f}\"\n",
        ")"
      ],
      "metadata": {
        "id": "k3lEYETpDswT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"Ten molecules from largest cluster:\")\n",
        "# Show up to ten members of the largest cluster, labelled with their ChEMBL IDs\n",
        "largest_cluster_members = [compounds[i] for i in clusters[0][:10]]\n",
        "Draw.MolsToGridImage(\n",
        "    [member_mol for member_mol, _ in largest_cluster_members],\n",
        "    legends=[member_id for _, member_id in largest_cluster_members],\n",
        "    molsPerRow=5,\n",
        ")"
      ],
      "metadata": {
        "id": "rz__tGBhD1F-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Save molecules from largest cluster so other talktorials can use it\n",
        "sdf_path = \"molecule_set_largest_cluster.sdf\"\n",
        "sdf = Chem.SDWriter(sdf_path)\n",
        "try:\n",
        "    for index in clusters[0]:\n",
        "        mol, label = compounds[index]\n",
        "        # Store the ChEMBL ID in the molecule's \"_Name\" property before writing\n",
        "        mol.SetProp(\"_Name\", label)\n",
        "        sdf.write(mol)\n",
        "finally:\n",
        "    # Always close the writer so the file is flushed and released, even if a write fails\n",
        "    sdf.close()"
      ],
      "metadata": {
        "id": "sRJlL7K2D3zO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"Ten molecules from second largest cluster:\")\n",
        "# Show up to ten members of the second largest cluster, labelled with their ChEMBL IDs\n",
        "second_cluster_members = [compounds[i] for i in clusters[1][:10]]\n",
        "Draw.MolsToGridImage(\n",
        "    [member_mol for member_mol, _ in second_cluster_members],\n",
        "    legends=[member_id for _, member_id in second_cluster_members],\n",
        "    molsPerRow=5,\n",
        ")"
      ],
      "metadata": {
        "id": "2aD1Wrg8D-lv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"Ten molecules from first 10 clusters:\")\n",
        "# Draw the first member of each of the (up to) ten largest clusters.\n",
        "# Slicing clusters[:10] instead of indexing with range(10) avoids an IndexError\n",
        "# when the clustering produced fewer than ten clusters.\n",
        "first_members = [compounds[cluster[0]] for cluster in clusters[:10]]\n",
        "Draw.MolsToGridImage(\n",
        "    [member_mol for member_mol, _ in first_members],\n",
        "    legends=[member_id for _, member_id in first_members],\n",
        "    molsPerRow=5,\n",
        ")"
      ],
      "metadata": {
        "id": "oTDrWCFGEGR_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Grid image with the first molecule of each of the three largest clusters\n",
        "representative_mols = [compounds[clusters[rank][0]][0] for rank in range(3)]\n",
        "img = Draw.MolsToGridImage(\n",
        "    representative_mols,\n",
        "    legends=[f\"Cluster {i}\" for i in range(1, 4)],\n",
        "    subImgSize=(200, 200),\n",
        "    useSVG=True,\n",
        ")\n",
        "\n",
        "# Patch the raw SVG text: make the background transparent, then enlarge the font\n",
        "molsvg = img.replace(\"opacity:1.0\", \"opacity:0.0\")\n",
        "molsvg = molsvg.replace(\"12px\", \"20px\")\n",
        "\n",
        "# Write the patched SVG text to a file\n",
        "with open(\"cluster_representatives.svg\", \"w\") as f:\n",
        "    f.write(molsvg)"
      ],
      "metadata": {
        "id": "TYu5ERd7ENPP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def intra_tanimoto(fps_clusters):\n",
        "    \"\"\"Return, for each cluster, the Tanimoto similarities of all its fingerprint pairs.\n",
        "\n",
        "    `fps_clusters` is an iterable of fingerprint lists (one list per cluster); the result\n",
        "    holds one flat list of pairwise similarities per cluster, in the same order.\n",
        "    \"\"\"\n",
        "    # tanimoto_distance_matrix yields distances, so convert back with 1 - distance\n",
        "    return [\n",
        "        [1 - distance for distance in tanimoto_distance_matrix(cluster_fps)]\n",
        "        for cluster_fps in fps_clusters\n",
        "    ]"
      ],
      "metadata": {
        "id": "SbkjGFAtElV2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Fingerprints of the members of the first ten clusters, one list per cluster\n",
        "mol_fps_per_cluster = [\n",
        "    [rdkit_gen.GetFingerprint(compounds[i][0]) for i in members] for members in clusters[:10]\n",
        "]\n",
        "\n",
        "# Pairwise similarities within each of those clusters\n",
        "intra_sim = intra_tanimoto(mol_fps_per_cluster)"
      ],
      "metadata": {
        "id": "EB6o0NUvEnK2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Violin plot with intra-cluster similarity\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(10, 5))\n",
        "# One violin per analysed cluster: derive the positions from the data instead of\n",
        "# hard-coding 10, so the plot still works when fewer clusters are available\n",
        "indices = list(range(len(intra_sim)))\n",
        "ax.set_xlabel(\"Cluster index\")\n",
        "ax.set_ylabel(\"Similarity\")\n",
        "ax.set_xticks(indices)\n",
        "ax.set_xticklabels(indices)\n",
        "ax.set_yticks(numpy.arange(0.6, 1.0, 0.1))\n",
        "ax.set_title(\"Intra-cluster Tanimoto similarity\", fontsize=13)\n",
        "r = ax.violinplot(intra_sim, indices, showmeans=True, showmedians=True, showextrema=False)\n",
        "r[\"cmeans\"].set_color(\"red\")\n",
        "# mean=red, median=blue"
      ],
      "metadata": {
        "id": "by8B4MSYEogY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# The first entry of every cluster is treated as its center;\n",
        "# collect the matching (molecule, ChEMBL ID) pairs\n",
        "cluster_centers = []\n",
        "for members in clusters:\n",
        "    cluster_centers.append(compounds[members[0]])\n",
        "# Report the number of cluster centers, i.e. the number of clusters\n",
        "print(\"Number of cluster centers:\", len(cluster_centers))\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "id": "CogA_W54JK2X"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sort the molecules within a cluster based on their similarity\n",
        "# to the cluster center and sort the clusters based on their size\n",
        "sorted_clusters = []\n",
        "for cluster in clusters:\n",
        "    if len(cluster) <= 1:\n",
        "        continue  # Singletons have no members besides the center\n",
        "    # Compute fingerprints for each cluster element\n",
        "    sorted_fingerprints = [rdkit_gen.GetFingerprint(compounds[i][0]) for i in cluster]\n",
        "    # Similarity of all other cluster members to the cluster center (first element)\n",
        "    similarities = DataStructs.BulkTanimotoSimilarity(\n",
        "        sorted_fingerprints[0], sorted_fingerprints[1:]\n",
        "    )\n",
        "    # Pair each similarity with the index of its molecule (centroid excluded!)\n",
        "    similarities = list(zip(similarities, cluster[1:]))\n",
        "    # Sort in descending order by similarity\n",
        "    similarities.sort(reverse=True)\n",
        "    # Save the number of non-center members and their indices, most similar first\n",
        "    sorted_clusters.append((len(similarities), [i for _, i in similarities]))\n",
        "# Sort once, after the loop, in descending order by cluster size;\n",
        "# re-sorting the growing list on every iteration gives the same result but wastes time\n",
        "sorted_clusters.sort(reverse=True)"
      ],
      "metadata": {
        "id": "pskLQ2S9JOnl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Start from the cluster centers, then top up with the cluster members\n",
        "# that are most similar to their center\n",
        "selected_molecules = cluster_centers.copy()\n",
        "# Walk through the clusters, largest first, until 1000 molecules are selected\n",
        "index = 0\n",
        "pending = 1000 - len(selected_molecules)\n",
        "while index < len(sorted_clusters) and pending > 0:\n",
        "    # Stored size and member indices (center excluded, most similar first)\n",
        "    cluster_size, tmp_cluster = sorted_clusters[index]\n",
        "    # More than 10 members: take exactly 10; otherwise take half of them plus one\n",
        "    num_compounds = 10 if cluster_size > 10 else int(0.5 * len(tmp_cluster)) + 1\n",
        "    # Never pick more than the number of molecules still missing\n",
        "    num_compounds = min(num_compounds, pending)\n",
        "    # Append the picked (molecule, ChEMBL ID) pairs to the selection\n",
        "    selected_molecules.extend(compounds[i] for i in tmp_cluster[:num_compounds])\n",
        "    index += 1\n",
        "    pending = 1000 - len(selected_molecules)\n",
        "print(\"# Selected molecules:\", len(selected_molecules))\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "id": "3o9Pwd95JQlt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Reuse the old dataset: work on a shallow copy of the compound list\n",
        "sampled_mols = list(compounds)"
      ],
      "metadata": {
        "id": "sZJMGoIuJYEN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Helper function for time computation\n",
        "def measure_runtime(sampled_mols):\n",
        "    \"\"\"Return the seconds needed to fingerprint and cluster `sampled_mols`.\n",
        "\n",
        "    `sampled_mols` is a list of (molecule, label) pairs; fingerprints are generated with\n",
        "    the notebook-level `rdkit_gen` and clustered with a distance cut-off of 0.3.\n",
        "    \"\"\"\n",
        "    # perf_counter is monotonic and high-resolution, unlike the wall clock of time.time()\n",
        "    start_time = time.perf_counter()\n",
        "    sampled_fingerprints = [rdkit_gen.GetFingerprint(mol) for mol, _ in sampled_mols]\n",
        "    # Only the elapsed time matters here, so the clustering result is discarded\n",
        "    cluster_fingerprints(sampled_fingerprints, cutoff=0.3)\n",
        "    return time.perf_counter() - start_time"
      ],
      "metadata": {
        "id": "lVBhlkOiJeTd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Number of molecules available for the runtime benchmark\n",
        "len(sampled_mols)\n",
        "# NBVAL_CHECK_OUTPUT"
      ],
      "metadata": {
        "id": "yofFXJ1vJgrO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# random.sample draws WITHOUT replacement, so a sample can never be larger than the\n",
        "# dataset: sizes that do not fit are skipped instead of raising a ValueError\n",
        "sample_sizes = [size for size in [100, 500, 1000, 2000, 4000] if size <= len(sampled_mols)]\n",
        "runtimes = []\n",
        "# Time fingerprinting + clustering on a random subset of each size\n",
        "for size in sample_sizes:\n",
        "    time_taken = measure_runtime(random.sample(sampled_mols, size))\n",
        "    print(f\"Dataset size {size}, time {time_taken:4.2f} seconds\")\n",
        "    runtimes.append(time_taken)"
      ],
      "metadata": {
        "id": "3yWndEPJJice"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}