--- a
+++ b/ML_5.ipynb
@@ -0,0 +1,604 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "authorship_tag": "ABX9TyOHZYm2RG/vTEZ6jZiHZihY",
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "view-in-github",
+        "colab_type": "text"
+      },
+      "source": [
+        "<a href=\"https://colab.research.google.com/github/francescopatane96/Computer_aided_drug_discovery_kit/blob/main/ML_5.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# %pip (instead of !pip) installs into the environment of the running kernel\n",
+        "# TODO: pin the rdkit version for reproducible runs\n",
+        "%pip install rdkit"
+      ],
+      "metadata": {
+        "id": "2D4euSqkBVDd"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# %pip (instead of !pip) installs into the environment of the running kernel\n",
+        "# TODO: pin a tag or commit of the repository for reproducible runs\n",
+        "%pip install git+https://github.com/volkamerlab/teachopencadd.git"
+      ],
+      "metadata": {
+        "id": "myQilAQCBeKl"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import time\n",
+        "import random\n",
+        "from pathlib import Path\n",
+        "\n",
+        "import pandas as pd\n",
+        "import numpy\n",
+        "import matplotlib.pyplot as plt\n",
+        "from rdkit import Chem\n",
+        "from rdkit import DataStructs\n",
+        "from rdkit.ML.Cluster import Butina\n",
+        "from rdkit.Chem import Draw\n",
+        "from rdkit.Chem import rdFingerprintGenerator\n",
+        "\n",
+        "from teachopencadd.utils import seed_everything\n",
+        "\n",
+        "seed_everything()  # fix seed to get deterministic outputs"
+      ],
+      "metadata": {
+        "id": "GyKKqPxiBTs2"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Read the compound table; the first CSV column becomes the dataframe index\n",
+        "# (filtered data taken from **Talktorial T002**)\n",
+        "compound_df = pd.read_csv(\"TNFB_compounds_lipinski.csv\", index_col=0)\n",
+        "print(\"Dataframe shape:\", compound_df.shape)\n",
+        "# Preview the first rows\n",
+        "compound_df.head()"
+      ],
+      "metadata": {
+        "id": "7hQCDG1LC42R"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Create molecules from SMILES and store them as (molecule, chembl_id) tuples in a list\n",
+        "compounds = []\n",
+        "invalid_ids = []\n",
+        "# Iterate over just the two columns we need; zip avoids carrying the unused\n",
+        "# dataframe index (and avoids rebinding `_`, which IPython uses for the last output)\n",
+        "for chembl_id, smiles in zip(compound_df[\"molecule_chembl_id\"], compound_df[\"smiles\"]):\n",
+        "    mol = Chem.MolFromSmiles(smiles)\n",
+        "    if mol is None:\n",
+        "        # RDKit returns None for SMILES it cannot parse; such entries would\n",
+        "        # break the fingerprint generation further down, so leave them out\n",
+        "        invalid_ids.append(chembl_id)\n",
+        "        continue\n",
+        "    compounds.append((mol, chembl_id))\n",
+        "if invalid_ids:\n",
+        "    print(f\"Skipped {len(invalid_ids)} compounds with unparsable SMILES:\", invalid_ids[:10])\n",
+        "compounds[:5]"
+      ],
+      "metadata": {
+        "id": "tHvZx_QmDBUu"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Build one RDKit fingerprint per molecule (generator configured with maxPath=5)\n",
+        "rdkit_gen = rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=5)\n",
+        "fingerprints = [rdkit_gen.GetFingerprint(molecule) for molecule, label in compounds]\n",
+        "\n",
+        "# Report the number of fingerprints and the length of a single one\n",
+        "print(\"Number of compounds converted:\", len(fingerprints))\n",
+        "print(\"Fingerprint length per compound:\", len(fingerprints[0]))\n",
+        "# NBVAL_CHECK_OUTPUT"
+      ],
+      "metadata": {
+        "id": "5y-nTbl3DDc-"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def tanimoto_distance_matrix(fp_list):\n",
+        "    \"\"\"Return the pairwise Tanimoto distances of a fingerprint list.\n",
+        "\n",
+        "    The result is a flat list holding the lower triangle of the distance\n",
+        "    matrix (diagonal excluded), row by row: every fingerprint is compared\n",
+        "    against all fingerprints that precede it in ``fp_list``.\n",
+        "    \"\"\"\n",
+        "    distances = []\n",
+        "    # Start at the second fingerprint: the first one has no predecessors to compare against\n",
+        "    for position in range(1, len(fp_list)):\n",
+        "        # Similarity of the current fingerprint to every earlier one\n",
+        "        sims = DataStructs.BulkTanimotoSimilarity(fp_list[position], fp_list[:position])\n",
+        "        # Distance is defined as 1 - similarity\n",
+        "        distances += [1 - value for value in sims]\n",
+        "    return distances"
+      ],
+      "metadata": {
+        "id": "d9IDVZRADIRu"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Example: Tanimoto similarity (and the matching distance) of the first two fingerprints\n",
+        "# NBVAL_CHECK_OUTPUT\n",
+        "sim = DataStructs.TanimotoSimilarity(\n",
+        "    fingerprints[0],\n",
+        "    fingerprints[1],\n",
+        ")\n",
+        "print(f\"Tanimoto similarity: {sim:.2f}, distance: {1-sim:.2f}\")"
+      ],
+      "metadata": {
+        "id": "QauDa9HyDKVO"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Example: first five entries of the distance matrix (distance = 1 - similarity)\n",
+        "tanimoto_distance_matrix(fingerprints)[:5]"
+      ],
+      "metadata": {
+        "id": "yY0K_IF7DMKu"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Side note: the result above is a flat list rather than a 2D matrix.\n",
+        "# It holds the lower triangle of the distance matrix (diagonal excluded)\n",
+        "n = len(fingerprints)\n",
+        "\n",
+        "# A triangular matrix without its diagonal has n*(n-1)/2 elements\n",
+        "elem_triangular_matr = n * (n - 1) / 2\n",
+        "print(\n",
+        "    f\"Elements in the triangular matrix ({elem_triangular_matr:.0f}) ==\",\n",
+        "    f\"tanimoto_distance_matrix(fingerprints) ({len(tanimoto_distance_matrix(fingerprints))})\",\n",
+        ")\n",
+        "# NBVAL_CHECK_OUTPUT"
+      ],
+      "metadata": {
+        "id": "iccQP_zbDOX-"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def cluster_fingerprints(fingerprints, cutoff=0.2):\n",
+        "    \"\"\"Group fingerprints with RDKit's Butina clustering.\n",
+        "\n",
+        "    Parameters:\n",
+        "        fingerprints: fingerprints to cluster\n",
+        "        cutoff: threshold handed to ``Butina.ClusterData``\n",
+        "\n",
+        "    Returns:\n",
+        "        list of clusters, ordered from the largest to the smallest\n",
+        "    \"\"\"\n",
+        "    n_fingerprints = len(fingerprints)\n",
+        "    # Flat list of Tanimoto distances (1 - similarity) between the fingerprints\n",
+        "    pairwise_distances = tanimoto_distance_matrix(fingerprints)\n",
+        "    # isDistData=True tells Butina that the input already is a distance matrix\n",
+        "    butina_clusters = Butina.ClusterData(\n",
+        "        pairwise_distances, n_fingerprints, cutoff, isDistData=True\n",
+        "    )\n",
+        "    # Largest clusters first\n",
+        "    return sorted(butina_clusters, key=len, reverse=True)"
+      ],
+      "metadata": {
+        "id": "5ZcQJY7MDRxW"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Cluster the dataset at a cut-off of 0.3\n",
+        "clusters = cluster_fingerprints(fingerprints, cutoff=0.3)\n",
+        "\n",
+        "# Count singletons and the clusters above a few size thresholds\n",
+        "num_clust_g1 = len([c for c in clusters if len(c) == 1])\n",
+        "num_clust_g5 = len([c for c in clusters if len(c) > 5])\n",
+        "num_clust_g25 = len([c for c in clusters if len(c) > 25])\n",
+        "num_clust_g100 = len([c for c in clusters if len(c) > 100])\n",
+        "\n",
+        "print(\"total # clusters: \", len(clusters))\n",
+        "print(\"# clusters with only 1 compound: \", num_clust_g1)\n",
+        "print(\"# clusters with >5 compounds: \", num_clust_g5)\n",
+        "print(\"# clusters with >25 compounds: \", num_clust_g25)\n",
+        "print(\"# clusters with >100 compounds: \", num_clust_g100)\n",
+        "# NBVAL_CHECK_OUTPUT"
+      ],
+      "metadata": {
+        "id": "H1XWfxqKDTxN"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Bar chart of the cluster sizes (clusters are ordered largest first)\n",
+        "fig, ax = plt.subplots(figsize=(15, 4))\n",
+        "ax.bar(range(1, len(clusters) + 1), [len(c) for c in clusters], lw=5)\n",
+        "ax.set_xlabel(\"Cluster index\")\n",
+        "ax.set_ylabel(\"Number of molecules\");"
+      ],
+      "metadata": {
+        "id": "lGxKiS9sDWxB"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Compare the cluster size distribution for several cut-offs\n",
+        "for cutoff in numpy.arange(0.0, 1.0, 0.2):\n",
+        "    # Loop-local name: the notebook-level `clusters` result is not overwritten\n",
+        "    clusters_at_cutoff = cluster_fingerprints(fingerprints, cutoff=cutoff)\n",
+        "    fig, ax = plt.subplots(figsize=(15, 4))\n",
+        "    ax.set_title(f\"Threshold: {cutoff:3.1f}\")\n",
+        "    ax.set_xlabel(\"Cluster index\")\n",
+        "    ax.set_ylabel(\"Number of molecules\")\n",
+        "    ax.bar(\n",
+        "        range(1, len(clusters_at_cutoff) + 1),\n",
+        "        [len(c) for c in clusters_at_cutoff],\n",
+        "        lw=5,\n",
+        "    )\n",
+        "    display(fig)\n",
+        "    # Close the figure once shown: figures left open are rendered a second time\n",
+        "    # by the inline backend at the end of the cell and stay in memory\n",
+        "    plt.close(fig)"
+      ],
+      "metadata": {
+        "id": "edbn5RRCDcQ-"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "cutoff = 0.2\n",
+        "clusters = cluster_fingerprints(fingerprints, cutoff=cutoff)\n",
+        "\n",
+        "# Plot the cluster sizes for the chosen cut-off and save the figure as PNG\n",
+        "fig, ax = plt.subplots(figsize=(15, 4))\n",
+        "ax.set_title(f\"Threshold: {cutoff:3.1f}\")\n",
+        "ax.set_xlabel(\"Cluster index\")\n",
+        "ax.set_ylabel(\"# molecules\")\n",
+        "ax.bar(range(1, len(clusters) + 1), [len(c) for c in clusters])\n",
+        "fig.savefig(f\"cluster_dist_cutoff_{cutoff:4.2f}.png\", dpi=300, bbox_inches=\"tight\", transparent=True)\n",
+        "\n",
+        "# Summary: cluster count, size of the largest cluster and two example similarities\n",
+        "# (first two members of the largest cluster; first members of the two largest clusters)\n",
+        "print(\n",
+        "    f\"Number of clusters: {len(clusters)} from {len(compounds)} molecules at distance cut-off {cutoff:.2f}\"\n",
+        ")\n",
+        "print(\"Number of molecules in largest cluster:\", len(clusters[0]))\n",
+        "print(\n",
+        "    f\"Similarity between two random points in same cluster: {DataStructs.TanimotoSimilarity(fingerprints[clusters[0][0]], fingerprints[clusters[0][1]]):.2f}\"\n",
+        ")\n",
+        "print(\n",
+        "    f\"Similarity between two random points in different cluster: {DataStructs.TanimotoSimilarity(fingerprints[clusters[0][0]], fingerprints[clusters[1][0]]):.2f}\"\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "k3lEYETpDswT"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "print(\"Ten molecules from largest cluster:\")\n",
+        "# Grid with (up to) the first ten members of the largest cluster, labelled with their IDs\n",
+        "Draw.MolsToGridImage(\n",
+        "    [compounds[member][0] for member in clusters[0][:10]],\n",
+        "    molsPerRow=5,\n",
+        "    legends=[compounds[member][1] for member in clusters[0][:10]],\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "rz__tGBhD1F-"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Save molecules from largest cluster so other talktorials can use it\n",
+        "sdf_path = \"molecule_set_largest_cluster.sdf\"\n",
+        "sdf = Chem.SDWriter(sdf_path)\n",
+        "try:\n",
+        "    for index in clusters[0]:\n",
+        "        mol, label = compounds[index]\n",
+        "        # Attach the label to the molecule via its \"_Name\" property\n",
+        "        mol.SetProp(\"_Name\", label)\n",
+        "        sdf.write(mol)\n",
+        "finally:\n",
+        "    # Always close the writer, even if writing one of the molecules fails\n",
+        "    sdf.close()"
+      ],
+      "metadata": {
+        "id": "sRJlL7K2D3zO"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "print(\"Ten molecules from second largest cluster:\")\n",
+        "# Grid with (up to) the first ten members of the second largest cluster, labelled with their IDs\n",
+        "Draw.MolsToGridImage(\n",
+        "    [compounds[member][0] for member in clusters[1][:10]],\n",
+        "    molsPerRow=5,\n",
+        "    legends=[compounds[member][1] for member in clusters[1][:10]],\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "2aD1Wrg8D-lv"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "print(\"Ten molecules from first 10 clusters:\")\n",
+        "# Draw the first member of each of the first ten clusters.\n",
+        "# Slicing (instead of range(10)) also works when fewer than ten clusters exist\n",
+        "Draw.MolsToGridImage(\n",
+        "    [compounds[cluster[0]][0] for cluster in clusters[:10]],\n",
+        "    legends=[compounds[cluster[0]][1] for cluster in clusters[:10]],\n",
+        "    molsPerRow=5,\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "oTDrWCFGEGR_"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Generate image with the first member of each of the first three clusters\n",
+        "# (slicing copes with fewer than three clusters)\n",
+        "representatives = [compounds[cluster[0]][0] for cluster in clusters[:3]]\n",
+        "img = Draw.MolsToGridImage(\n",
+        "    representatives,\n",
+        "    legends=[f\"Cluster {i}\" for i in range(1, len(representatives) + 1)],\n",
+        "    subImgSize=(200, 200),\n",
+        "    useSVG=True,\n",
+        ")\n",
+        "\n",
+        "# Inside a notebook RDKit may return an IPython SVG display object instead of a\n",
+        "# plain string; the raw markup is then available through its `.data` attribute\n",
+        "svg_text = img.data if hasattr(img, \"data\") else img\n",
+        "\n",
+        "# Patch RAW svg data: convert non-transparent to transparent background and set font size\n",
+        "molsvg = svg_text.replace(\"opacity:1.0\", \"opacity:0.0\").replace(\"12px\", \"20px\")\n",
+        "\n",
+        "# Save altered SVG data to file\n",
+        "with open(\"cluster_representatives.svg\", \"w\") as f:\n",
+        "    f.write(molsvg)"
+      ],
+      "metadata": {
+        "id": "TYu5ERd7ENPP"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def intra_tanimoto(fps_clusters):\n",
+        "    \"\"\"Compute the pairwise Tanimoto similarities inside every cluster.\n",
+        "\n",
+        "    Parameters:\n",
+        "        fps_clusters: list of clusters, each given as a list of fingerprints\n",
+        "\n",
+        "    Returns:\n",
+        "        one list of similarity values per cluster\n",
+        "    \"\"\"\n",
+        "    # tanimoto_distance_matrix yields distances (1 - similarity), so convert them back\n",
+        "    return [\n",
+        "        [1 - distance for distance in tanimoto_distance_matrix(cluster_fps)]\n",
+        "        for cluster_fps in fps_clusters\n",
+        "    ]"
+      ],
+      "metadata": {
+        "id": "SbkjGFAtElV2"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Collect the fingerprints of the members of the first 10 clusters.\n",
+        "# `fingerprints` was built from `compounds` in the same order, so the cluster\n",
+        "# indices can be used directly instead of regenerating every fingerprint\n",
+        "mol_fps_per_cluster = [[fingerprints[i] for i in cluster] for cluster in clusters[:10]]\n",
+        "\n",
+        "# Compute intra-cluster similarity\n",
+        "intra_sim = intra_tanimoto(mol_fps_per_cluster)"
+      ],
+      "metadata": {
+        "id": "EB6o0NUvEnK2"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Violin plot with intra-cluster similarity\n",
+        "\n",
+        "fig, ax = plt.subplots(figsize=(10, 5))\n",
+        "# One violin per entry of intra_sim: deriving the positions from the data keeps\n",
+        "# both in sync when fewer than 10 clusters are available\n",
+        "indices = list(range(len(intra_sim)))\n",
+        "ax.set_xlabel(\"Cluster index\")\n",
+        "ax.set_ylabel(\"Similarity\")\n",
+        "ax.set_xticks(indices)\n",
+        "ax.set_xticklabels(indices)\n",
+        "ax.set_yticks(numpy.arange(0.6, 1.0, 0.1))\n",
+        "ax.set_title(\"Intra-cluster Tanimoto similarity\", fontsize=13)\n",
+        "r = ax.violinplot(intra_sim, indices, showmeans=True, showmedians=True, showextrema=False)\n",
+        "r[\"cmeans\"].set_color(\"red\")\n",
+        "# mean=red, median=blue"
+      ],
+      "metadata": {
+        "id": "by8B4MSYEogY"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# The first molecule of every cluster is treated as its center:\n",
+        "# look up the matching (molecule, label) tuple for each cluster\n",
+        "cluster_centers = [compounds[members[0]] for members in clusters]\n",
+        "# Report the number of cluster centers (one per cluster)\n",
+        "print(\"Number of cluster centers:\", len(cluster_centers))\n",
+        "# NBVAL_CHECK_OUTPUT"
+      ],
+      "metadata": {
+        "id": "CogA_W54JK2X"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Sort the molecules within a cluster based on their similarity\n",
+        "# to the cluster center and sort the clusters based on their size\n",
+        "sorted_clusters = []\n",
+        "for cluster in clusters:\n",
+        "    if len(cluster) <= 1:\n",
+        "        continue  # Singletons have no members besides the center\n",
+        "    center_index, member_indices = cluster[0], cluster[1:]\n",
+        "    # Similarity of all cluster members to the cluster center; the fingerprints\n",
+        "    # computed earlier are reused instead of being regenerated\n",
+        "    similarities = DataStructs.BulkTanimotoSimilarity(\n",
+        "        fingerprints[center_index], [fingerprints[i] for i in member_indices]\n",
+        "    )\n",
+        "    # Pair every similarity with its molecule index (centroid excluded!) and\n",
+        "    # sort in descending order by similarity\n",
+        "    ranked_members = sorted(zip(similarities, member_indices), reverse=True)\n",
+        "    # Save cluster size and the similarity-ranked molecule indices\n",
+        "    sorted_clusters.append((len(ranked_members), [i for _, i in ranked_members]))\n",
+        "# Sort once, after the loop, in descending order by cluster size\n",
+        "sorted_clusters.sort(reverse=True)"
+      ],
+      "metadata": {
+        "id": "pskLQ2S9JOnl"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Upper bound for the total number of selected molecules\n",
+        "MAX_SELECTED = 1000\n",
+        "# Number of molecules taken from a cluster with more than this many ranked members\n",
+        "MAX_PER_CLUSTER = 10\n",
+        "\n",
+        "# Pick the cluster centers first\n",
+        "selected_molecules = cluster_centers.copy()\n",
+        "# Then walk through the clusters, largest first, and add their top-ranked members:\n",
+        "# 10 from every large cluster, about half of the members of each smaller one\n",
+        "index = 0\n",
+        "pending = MAX_SELECTED - len(selected_molecules)\n",
+        "while pending > 0 and index < len(sorted_clusters):\n",
+        "    # Each entry is (number of ranked members, ranked molecule indices)\n",
+        "    cluster_size, tmp_cluster = sorted_clusters[index]\n",
+        "    if cluster_size > MAX_PER_CLUSTER:\n",
+        "        num_compounds = MAX_PER_CLUSTER\n",
+        "    else:\n",
+        "        # Half of the members (rounded down) plus one\n",
+        "        num_compounds = len(tmp_cluster) // 2 + 1\n",
+        "    # Never take more than what is still missing to reach the upper bound\n",
+        "    num_compounds = min(num_compounds, pending)\n",
+        "    # Append the picked (molecule, label) tuples to the selection\n",
+        "    selected_molecules += [compounds[i] for i in tmp_cluster[:num_compounds]]\n",
+        "    index += 1\n",
+        "    pending = MAX_SELECTED - len(selected_molecules)\n",
+        "print(\"# Selected molecules:\", len(selected_molecules))\n",
+        "# NBVAL_CHECK_OUTPUT"
+      ],
+      "metadata": {
+        "id": "3o9Pwd95JQlt"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Work on a shallow copy of the compound list for the runtime measurements\n",
+        "sampled_mols = list(compounds)"
+      ],
+      "metadata": {
+        "id": "sZJMGoIuJYEN"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Helper function for time computation\n",
+        "def measure_runtime(sampled_mols):\n",
+        "    \"\"\"Return the seconds needed to fingerprint and cluster ``sampled_mols``.\n",
+        "\n",
+        "    Parameters:\n",
+        "        sampled_mols: iterable of (molecule, label) tuples\n",
+        "\n",
+        "    Returns:\n",
+        "        elapsed time in seconds as a float\n",
+        "    \"\"\"\n",
+        "    # perf_counter is a monotonic, high-resolution clock intended for measuring intervals\n",
+        "    start_time = time.perf_counter()\n",
+        "    sampled_fingerprints = [rdkit_gen.GetFingerprint(m) for m, idx in sampled_mols]\n",
+        "    # Run the clustering; only its duration matters, so the result is discarded\n",
+        "    cluster_fingerprints(sampled_fingerprints, cutoff=0.3)\n",
+        "    return time.perf_counter() - start_time"
+      ],
+      "metadata": {
+        "id": "lVBhlkOiJeTd"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Number of molecules available for the runtime measurements below\n",
+        "len(sampled_mols)\n",
+        "# NBVAL_CHECK_OUTPUT"
+      ],
+      "metadata": {
+        "id": "yofFXJ1vJgrO"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "sample_sizes = [100, 500, 1000, 2000, 4000]\n",
+        "# random.sample draws WITHOUT replacement and raises ValueError when asked for more\n",
+        "# items than the population holds, so drop the sizes the dataset cannot provide\n",
+        "sample_sizes = [size for size in sample_sizes if size <= len(sampled_mols)]\n",
+        "runtimes = []\n",
+        "# Time fingerprinting + clustering on a random subset of each remaining size\n",
+        "for size in sample_sizes:\n",
+        "    time_taken = measure_runtime(random.sample(sampled_mols, size))\n",
+        "    print(f\"Dataset size {size}, time {time_taken:4.2f} seconds\")\n",
+        "    runtimes.append(time_taken)"
+      ],
+      "metadata": {
+        "id": "3yWndEPJJice"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file