--- a +++ b/ML_5.ipynb @@ -0,0 +1,604 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyOHZYm2RG/vTEZ6jZiHZihY", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "<a href=\"https://colab.research.google.com/github/francescopatane96/Computer_aided_drug_discovery_kit/blob/main/ML_5.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install rdkit " + ], + "metadata": { + "id": "2D4euSqkBVDd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip install git+https://github.com/volkamerlab/teachopencadd.git" + ], + "metadata": { + "id": "myQilAQCBeKl" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import time\n", + "import random\n", + "from pathlib import Path\n", + "\n", + "import pandas as pd\n", + "import numpy\n", + "import matplotlib.pyplot as plt\n", + "from rdkit import Chem\n", + "from rdkit import DataStructs\n", + "from rdkit.ML.Cluster import Butina\n", + "from rdkit.Chem import Draw\n", + "from rdkit.Chem import rdFingerprintGenerator\n", + "\n", + "from teachopencadd.utils import seed_everything\n", + "\n", + "seed_everything() # fix seed to get deterministic outputs" + ], + "metadata": { + "id": "GyKKqPxiBTs2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Load and have a look into data\n", + "# Filtered data taken from **Talktorial T002**\n", + "compound_df = pd.read_csv(\n", + " \"TNFB_compounds_lipinski.csv\",\n", + " index_col=0,\n", + ")\n", + "print(\"Dataframe shape:\", compound_df.shape)\n", + "compound_df.head()" + ], + "metadata": { + "id": "7hQCDG1LC42R" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Create molecules from SMILES and store in array\n", + "compounds = []\n", + "# .itertuples() returns a (index, column1, column2, ...) tuple per row\n", + "# we don't need index so we use _ instead\n", + "# note how we are slicing the dataframe to only the two columns we need now\n", + "for _, chembl_id, smiles in compound_df[[\"molecule_chembl_id\", \"smiles\"]].itertuples():\n", + " compounds.append((Chem.MolFromSmiles(smiles), chembl_id))\n", + "compounds[:5]" + ], + "metadata": { + "id": "tHvZx_QmDBUu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Create fingerprints for all molecules\n", + "rdkit_gen = rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=5)\n", + "fingerprints = [rdkit_gen.GetFingerprint(mol) for mol, idx in compounds]\n", + "\n", + "# How many compounds/fingerprints do we have?\n", + "print(\"Number of compounds converted:\", len(fingerprints))\n", + "print(\"Fingerprint length per compound:\", len(fingerprints[0]))\n", + "# NBVAL_CHECK_OUTPUT" + ], + "metadata": { + "id": "5y-nTbl3DDc-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def tanimoto_distance_matrix(fp_list):\n", + " \"\"\"Calculate distance matrix for fingerprint list\"\"\"\n", + " dissimilarity_matrix = []\n", + " # Notice how we are deliberately skipping the first and last items in the list\n", + " # because we don't need to compare them against themselves\n", + " for i in range(1, len(fp_list)):\n", + " # Compare the current fingerprint against all the previous ones in the list\n", + " similarities = DataStructs.BulkTanimotoSimilarity(fp_list[i], fp_list[:i])\n", + " # Since we need a distance matrix, calculate 1-x for every element in similarity matrix\n", + " dissimilarity_matrix.extend([1 - x for x in similarities])\n", + " return dissimilarity_matrix" + ], + "metadata": { + "id": "d9IDVZRADIRu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Example: Calculate single similarity of two fingerprints\n", + "# NBVAL_CHECK_OUTPUT\n", + "sim = DataStructs.TanimotoSimilarity(fingerprints[0], fingerprints[1])\n", + "print(f\"Tanimoto similarity: {sim:.2f}, distance: {1-sim:.2f}\")" + ], + "metadata": { + "id": "QauDa9HyDKVO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Example: Calculate distance matrix (distance = 1-similarity)\n", + "tanimoto_distance_matrix(fingerprints)[0:5]" + ], + "metadata": { + "id": "yY0K_IF7DMKu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Side note: That looked like a list and not a matrix.\n", + "# But it is a triangular similarity matrix in the form of a list\n", + "n = len(fingerprints)\n", + "\n", + "# Calculate number of elements in triangular matrix via n*(n-1)/2\n", + "elem_triangular_matr = (n * (n - 1)) / 2\n", + "print(\n", + " f\"Elements in the triangular matrix ({elem_triangular_matr:.0f}) ==\",\n", + " f\"tanimoto_distance_matrix(fingerprints) ({len(tanimoto_distance_matrix(fingerprints))})\",\n", + ")\n", + "# NBVAL_CHECK_OUTPUT" + ], + "metadata": { + "id": "iccQP_zbDOX-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def cluster_fingerprints(fingerprints, cutoff=0.2):\n", + " \"\"\"Cluster fingerprints\n", + " Parameters:\n", + " fingerprints\n", + " cutoff: threshold for the clustering\n", + " \"\"\"\n", + " # Calculate Tanimoto distance matrix\n", + " distance_matrix = tanimoto_distance_matrix(fingerprints)\n", + " # Now cluster the data with the implemented Butina algorithm:\n", + " clusters = Butina.ClusterData(distance_matrix, len(fingerprints), cutoff, isDistData=True)\n", + " clusters = sorted(clusters, key=len, reverse=True)\n", + " return clusters" + ], + "metadata": { + "id": "5ZcQJY7MDRxW" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Run the clustering procedure for the dataset\n", + "clusters = cluster_fingerprints(fingerprints, cutoff=0.3)\n", + "\n", + "# Give a short report about the numbers of clusters and their sizes\n", + "num_clust_g1 = sum(1 for c in clusters if len(c) == 1)\n", + "num_clust_g5 = sum(1 for c in clusters if len(c) > 5)\n", + "num_clust_g25 = sum(1 for c in clusters if len(c) > 25)\n", + "num_clust_g100 = sum(1 for c in clusters if len(c) > 100)\n", + "\n", + "print(\"total # clusters: \", len(clusters))\n", + "print(\"# clusters with only 1 compound: \", num_clust_g1)\n", + "print(\"# clusters with >5 compounds: \", num_clust_g5)\n", + "print(\"# clusters with >25 compounds: \", num_clust_g25)\n", + "print(\"# clusters with >100 compounds: \", num_clust_g100)\n", + "# NBVAL_CHECK_OUTPUT" + ], + "metadata": { + "id": "H1XWfxqKDTxN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Plot the size of the clusters\n", + "fig, ax = plt.subplots(figsize=(15, 4))\n", + "ax.set_xlabel(\"Cluster index\")\n", + "ax.set_ylabel(\"Number of molecules\")\n", + "ax.bar(range(1, len(clusters) + 1), [len(c) for c in clusters], lw=5);" + ], + "metadata": { + "id": "lGxKiS9sDWxB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for cutoff in numpy.arange(0.0, 1.0, 0.2):\n", + " clusters = cluster_fingerprints(fingerprints, cutoff=cutoff)\n", + " fig, ax = plt.subplots(figsize=(15, 4))\n", + " ax.set_title(f\"Threshold: {cutoff:3.1f}\")\n", + " ax.set_xlabel(\"Cluster index\")\n", + " ax.set_ylabel(\"Number of molecules\")\n", + " ax.bar(range(1, len(clusters) + 1), [len(c) for c in clusters], lw=5)\n", + " display(fig)" + ], + "metadata": { + "id": "edbn5RRCDcQ-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "cutoff = 0.2\n", + "clusters = cluster_fingerprints(fingerprints, cutoff=cutoff)\n", + "\n", + "# Plot the size of the clusters - save plot\n", + "fig, ax = plt.subplots(figsize=(15, 4))\n", + "ax.set_xlabel(\"Cluster index\")\n", + "ax.set_ylabel(\"# molecules\")\n", + "ax.bar(range(1, len(clusters) + 1), [len(c) for c in clusters])\n", + "ax.set_title(f\"Threshold: {cutoff:3.1f}\")\n", + "fig.savefig(\n", + " f\"cluster_dist_cutoff_{cutoff:4.2f}.png\",\n", + " dpi=300,\n", + " bbox_inches=\"tight\",\n", + " transparent=True,\n", + ")\n", + "\n", + "print(\n", + " f\"Number of clusters: {len(clusters)} from {len(compounds)} molecules at distance cut-off {cutoff:.2f}\"\n", + ")\n", + "print(\"Number of molecules in largest cluster:\", len(clusters[0]))\n", + "print(\n", + " f\"Similarity between two random points in same cluster: {DataStructs.TanimotoSimilarity(fingerprints[clusters[0][0]], fingerprints[clusters[0][1]]):.2f}\"\n", + ")\n", + "print(\n", + " f\"Similarity between two random points in different cluster: {DataStructs.TanimotoSimilarity(fingerprints[clusters[0][0]], fingerprints[clusters[1][0]]):.2f}\"\n", + ")" + ], + "metadata": { + "id": "k3lEYETpDswT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(\"Ten molecules from largest cluster:\")\n", + "# Draw molecules\n", + "Draw.MolsToGridImage(\n", + " [compounds[i][0] for i in clusters[0][:10]],\n", + " legends=[compounds[i][1] for i in clusters[0][:10]],\n", + " molsPerRow=5,\n", + ")" + ], + "metadata": { + "id": "rz__tGBhD1F-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Save molecules from largest cluster so other talktorials can use it\n", + "sdf_path = str(\"molecule_set_largest_cluster.sdf\")\n", + "sdf = Chem.SDWriter(sdf_path)\n", + "for index in clusters[0]:\n", + " mol, label = compounds[index]\n", + " # Add label to metadata\n", + " mol.SetProp(\"_Name\", label)\n", + " sdf.write(mol)\n", + "sdf.close()" + ], + "metadata": { + "id": "sRJlL7K2D3zO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(\"Ten molecules from second largest cluster:\")\n", + "# Draw molecules\n", + "Draw.MolsToGridImage(\n", + " [compounds[i][0] for i in clusters[1][:10]],\n", + " legends=[compounds[i][1] for i in clusters[1][:10]],\n", + " molsPerRow=5,\n", + ")" + ], + "metadata": { + "id": "2aD1Wrg8D-lv" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(\"Ten molecules from first 10 clusters:\")\n", + "# Draw molecules\n", + "Draw.MolsToGridImage(\n", + " [compounds[clusters[i][0]][0] for i in range(10)],\n", + " legends=[compounds[clusters[i][0]][1] for i in range(10)],\n", + " molsPerRow=5,\n", + ")" + ], + "metadata": { + "id": "oTDrWCFGEGR_" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Generate image\n", + "img = Draw.MolsToGridImage(\n", + " [compounds[clusters[i][0]][0] for i in range(0, 3)],\n", + " legends=[f\"Cluster {i}\" for i in range(1, 4)],\n", + " subImgSize=(200, 200),\n", + " useSVG=True,\n", + ")\n", + "\n", + "# Patch RAW svg data: convert non-transparent to transparent background and set font size\n", + "molsvg = img.replace(\"opacity:1.0\", \"opacity:0.0\").replace(\"12px\", \"20px\")\n", + "\n", + "# Save altered SVG data to file\n", + "with open(\"cluster_representatives.svg\", \"w\") as f:\n", + " f.write(molsvg)" + ], + "metadata": { + "id": "TYu5ERd7ENPP" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def intra_tanimoto(fps_clusters):\n", + " \"\"\"Function to compute Tanimoto similarity for all pairs of fingerprints in each cluster\"\"\"\n", + " intra_similarity = []\n", + " # Calculate intra similarity per cluster\n", + " for cluster in fps_clusters:\n", + " # Tanimoto distance matrix function converted to similarity matrix (1-distance)\n", + " intra_similarity.append([1 - x for x in tanimoto_distance_matrix(cluster)])\n", + " return intra_similarity" + ], + "metadata": { + "id": "SbkjGFAtElV2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Recompute fingerprints for 10 first clusters\n", + "mol_fps_per_cluster = []\n", + "for cluster in clusters[:10]:\n", + " mol_fps_per_cluster.append([rdkit_gen.GetFingerprint(compounds[i][0]) for i in cluster])\n", + "\n", + "# Compute intra-cluster similarity\n", + "intra_sim = intra_tanimoto(mol_fps_per_cluster)" + ], + "metadata": { + "id": "EB6o0NUvEnK2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Violin plot with intra-cluster similarity\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 5))\n", + "indices = list(range(10))\n", + "ax.set_xlabel(\"Cluster index\")\n", + "ax.set_ylabel(\"Similarity\")\n", + "ax.set_xticks(indices)\n", + "ax.set_xticklabels(indices)\n", + "ax.set_yticks(numpy.arange(0.6, 1.0, 0.1))\n", + "ax.set_title(\"Intra-cluster Tanimoto similarity\", fontsize=13)\n", + "r = ax.violinplot(intra_sim, indices, showmeans=True, showmedians=True, showextrema=False)\n", + "r[\"cmeans\"].set_color(\"red\")\n", + "# mean=red, median=blue" + ], + "metadata": { + "id": "by8B4MSYEogY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Get the cluster center of each cluster (first molecule in each cluster)\n", + "cluster_centers = [compounds[c[0]] for c in clusters]\n", + "# How many cluster centers/clusters do we have?\n", + "print(\"Number of cluster centers:\", len(cluster_centers))\n", + "# NBVAL_CHECK_OUTPUT" + ], + "metadata": { + "id": "CogA_W54JK2X" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Sort the molecules within a cluster based on their similarity\n", + "# to the cluster center and sort the clusters based on their size\n", + "sorted_clusters = []\n", + "for cluster in clusters:\n", + " if len(cluster) <= 1:\n", + " continue # Singletons\n", + " # else:\n", + " # Compute fingerprints for each cluster element\n", + " sorted_fingerprints = [rdkit_gen.GetFingerprint(compounds[i][0]) for i in cluster]\n", + " # Similarity of all cluster members to the cluster center\n", + " similarities = DataStructs.BulkTanimotoSimilarity(\n", + " sorted_fingerprints[0], sorted_fingerprints[1:]\n", + " )\n", + " # Add index of the molecule to its similarity (centroid excluded!)\n", + " similarities = list(zip(similarities, cluster[1:]))\n", + " # Sort in descending order by similarity\n", + " similarities.sort(reverse=True)\n", + " # Save cluster size and index of molecules in clusters_sort\n", + " sorted_clusters.append((len(similarities), [i for _, i in similarities]))\n", + " # Sort in descending order by cluster size\n", + " sorted_clusters.sort(reverse=True)" + ], + "metadata": { + "id": "pskLQ2S9JOnl" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Count selected molecules, pick cluster centers first\n", + "selected_molecules = cluster_centers.copy()\n", + "# Take 10 molecules (or a maximum of 50%) of each cluster starting with the largest one\n", + "index = 0\n", + "pending = 1000 - len(selected_molecules)\n", + "while pending > 0 and index < len(sorted_clusters):\n", + " # Take indices of sorted clusters\n", + " tmp_cluster = sorted_clusters[index][1]\n", + " # If the first cluster is > 10 big then take exactly 10 compounds\n", + " if sorted_clusters[index][0] > 10:\n", + " num_compounds = 10\n", + " # If smaller, take half of the molecules\n", + " else:\n", + " num_compounds = int(0.5 * len(tmp_cluster)) + 1\n", + " if num_compounds > pending:\n", + " num_compounds = pending\n", + " # Write picked molecules and their structures into list of lists called picked_fps\n", + " selected_molecules += [compounds[i] for i in tmp_cluster[:num_compounds]]\n", + " index += 1\n", + " pending = 1000 - len(selected_molecules)\n", + "print(\"# Selected molecules:\", len(selected_molecules))\n", + "# NBVAL_CHECK_OUTPUT" + ], + "metadata": { + "id": "3o9Pwd95JQlt" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Reuse old dataset\n", + "sampled_mols = compounds.copy()" + ], + "metadata": { + "id": "sZJMGoIuJYEN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Helper function for time computation\n", + "def measure_runtime(sampled_mols):\n", + " start_time = time.time()\n", + " sampled_fingerprints = [rdkit_gen.GetFingerprint(m) for m, idx in sampled_mols]\n", + " # Run the clustering with the dataset\n", + " sampled_clusters = cluster_fingerprints(sampled_fingerprints, cutoff=0.3)\n", + " return time.time() - start_time" + ], + "metadata": { + "id": "lVBhlkOiJeTd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "len(sampled_mols)\n", + "# NBVAL_CHECK_OUTPUT" + ], + "metadata": { + "id": "yofFXJ1vJgrO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "sample_sizes = [100, 500, 1000, 2000, 4000]\n", + "runtimes = []\n", + "# Take random samples with replacement\n", + "for size in sample_sizes:\n", + " time_taken = measure_runtime(random.sample(sampled_mols, size))\n", + " print(f\"Dataset size {size}, time {time_taken:4.2f} seconds\")\n", + " runtimes.append(time_taken)" + ], + "metadata": { + "id": "3yWndEPJJice" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file