--- a
+++ b/ML_2.ipynb
@@ -0,0 +1,433 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "authorship_tag": "ABX9TyMMu4hPOYs2eF6sHNgEIp8p",
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "view-in-github",
+        "colab_type": "text"
+      },
+      "source": [
+        "<a href=\"https://colab.research.google.com/github/francescopatane96/Computer_aided_drug_discovery_kit/blob/main/ML_2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Install RDKit into the environment of the running kernel (needed e.g. on Colab)\n",
+        "%pip install rdkit"
+      ],
+      "metadata": {
+        "id": "1fsxkuZvGP6c"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "UfeulBrhE0uV"
+      },
+      "outputs": [],
+      "source": [
+        "from pathlib import Path\n",
+        "import math\n",
+        "\n",
+        "import numpy as np\n",
+        "import pandas as pd\n",
+        "import matplotlib.pyplot as plt\n",
+        "from matplotlib.lines import Line2D\n",
+        "import matplotlib.patches as mpatches\n",
+        "from rdkit import Chem\n",
+        "from rdkit.Chem import Descriptors, Draw, PandasTools"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Load the compound table; the first CSV column is used as the row index\n",
+        "molecules = pd.read_csv(filepath_or_buffer=\"TNFB_compounds.csv\", index_col=0)\n",
+        "# Print the (rows, columns) shape, then preview the first five rows\n",
+        "print(molecules.shape)\n",
+        "molecules.head(n=5)"
+      ],
+      "metadata": {
+        "id": "-9yrLGstF253"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def calculate_ro5_properties(smiles):\n",
+        "    \"\"\"\n",
+        "    Compute the Lipinski rule-of-five (Ro5) properties of a molecule given as SMILES.\n",
+        "\n",
+        "    Parameters\n",
+        "    ----------\n",
+        "    smiles : str\n",
+        "        SMILES for a molecule.\n",
+        "\n",
+        "    Returns\n",
+        "    -------\n",
+        "    pandas.Series\n",
+        "        Molecular weight, number of hydrogen bond acceptors/donors, logP value\n",
+        "        and Ro5 compliance flag for the input molecule (index: \"molecular_weight\",\n",
+        "        \"n_hba\", \"n_hbd\", \"logp\", \"ro5_fulfilled\").\n",
+        "\n",
+        "    Raises\n",
+        "    ------\n",
+        "    ValueError\n",
+        "        If RDKit cannot parse `smiles` into a molecule.\n",
+        "    \"\"\"\n",
+        "    # RDKit molecule from SMILES; MolFromSmiles returns None for unparsable input,\n",
+        "    # so fail early with a message that names the offending SMILES\n",
+        "    molecule = Chem.MolFromSmiles(smiles)\n",
+        "    if molecule is None:\n",
+        "        raise ValueError(f\"Could not parse SMILES: {smiles!r}\")\n",
+        "    # Calculate Ro5-relevant chemical properties\n",
+        "    molecular_weight = Descriptors.ExactMolWt(molecule)\n",
+        "    n_hba = Descriptors.NumHAcceptors(molecule)\n",
+        "    n_hbd = Descriptors.NumHDonors(molecule)\n",
+        "    logp = Descriptors.MolLogP(molecule)\n",
+        "    # Check each of the four Ro5 conditions\n",
+        "    conditions = [molecular_weight <= 500, n_hba <= 10, n_hbd <= 5, logp <= 5]\n",
+        "    # Compliant if no more than one out of four conditions is violated\n",
+        "    ro5_fulfilled = sum(conditions) >= 3\n",
+        "    return pd.Series(\n",
+        "        [molecular_weight, n_hba, n_hbd, logp, ro5_fulfilled],\n",
+        "        index=[\"molecular_weight\", \"n_hba\", \"n_hbd\", \"logp\", \"ro5_fulfilled\"],\n",
+        "    )"
+      ],
+      "metadata": {
+        "id": "k2R7G3g6Gj59"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Compute the Ro5 properties for every SMILES string (one result row per compound)\n",
+        "ro5_properties = molecules.loc[:, \"smiles\"].apply(calculate_ro5_properties)\n",
+        "ro5_properties.head(n=5)"
+      ],
+      "metadata": {
+        "id": "fmgctHtRGc69"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Attach the Ro5 columns to the compound table. Any previous copy of these columns\n",
+        "# is dropped first, so re-running this cell does not create duplicate columns.\n",
+        "molecules = pd.concat(\n",
+        "    [molecules.drop(columns=ro5_properties.columns, errors=\"ignore\"), ro5_properties],\n",
+        "    axis=1,\n",
+        ")\n",
+        "molecules.head()"
+      ],
+      "metadata": {
+        "id": "alCYp2v4GwmV"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Split the compounds by their Ro5 compliance flag\n",
+        "molecules_ro5_fulfilled = molecules.loc[molecules[\"ro5_fulfilled\"]]\n",
+        "molecules_ro5_violated = molecules.loc[~molecules[\"ro5_fulfilled\"]]\n",
+        "\n",
+        "# Report the size of the full data set and of both subsets\n",
+        "print(f\"# compounds in unfiltered data set: {molecules.shape[0]}\")\n",
+        "print(f\"# compounds in filtered data set: {molecules_ro5_fulfilled.shape[0]}\")\n",
+        "print(f\"# compounds not compliant with the Ro5: {molecules_ro5_violated.shape[0]}\")"
+      ],
+      "metadata": {
+        "id": "VoegvlyJHI4k"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Persist the Ro5-compliant compounds and preview what was written\n",
+        "molecules_ro5_fulfilled.to_csv(path_or_buf=\"TNFB_compounds_lipinski.csv\")\n",
+        "molecules_ro5_fulfilled.head(n=5)"
+      ],
+      "metadata": {
+        "id": "7VM0nYh2HPIU"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def calculate_mean_std(dataframe):\n",
+        "    \"\"\"\n",
+        "    Summarise each property of a data set by its mean and standard deviation.\n",
+        "\n",
+        "    Parameters\n",
+        "    ----------\n",
+        "    dataframe : pd.DataFrame\n",
+        "        Properties (columns) for a set of items (rows).\n",
+        "\n",
+        "    Returns\n",
+        "    -------\n",
+        "    pd.DataFrame\n",
+        "        One row per property with the columns \"mean\" and \"std\".\n",
+        "    \"\"\"\n",
+        "    # describe() reports the statistical measures as rows: keep only the two of\n",
+        "    # interest, then transpose so that the properties become the rows\n",
+        "    summary = dataframe.describe()\n",
+        "    return summary.loc[[\"mean\", \"std\"]].T"
+      ],
+      "metadata": {
+        "id": "6cyggR4kHeD-"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Mean and standard deviation of the four Ro5 properties for the compliant compounds\n",
+        "molecules_ro5_fulfilled_stats = calculate_mean_std(\n",
+        "    molecules_ro5_fulfilled.loc[:, [\"molecular_weight\", \"n_hba\", \"n_hbd\", \"logp\"]]\n",
+        ")\n",
+        "molecules_ro5_fulfilled_stats"
+      ],
+      "metadata": {
+        "id": "jNn09KprHsl1"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Mean and standard deviation of the four Ro5 properties for the non-compliant compounds\n",
+        "molecules_ro5_violated_stats = calculate_mean_std(\n",
+        "    molecules_ro5_violated.loc[:, [\"molecular_weight\", \"n_hba\", \"n_hbd\", \"logp\"]]\n",
+        ")\n",
+        "molecules_ro5_violated_stats"
+      ],
+      "metadata": {
+        "id": "AkiyjrV-IDYt"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def _scale_by_thresholds(stats, thresholds, scaled_threshold):\n",
+        "    \"\"\"\n",
+        "    Rescale per-property statistics so that every property's threshold maps to one value.\n",
+        "\n",
+        "    Parameters\n",
+        "    ----------\n",
+        "    stats : pd.DataFrame\n",
+        "        Dataframe with \"mean\" and \"std\" (columns) for each physicochemical property (rows).\n",
+        "    thresholds : dict of str: int\n",
+        "        Threshold defined for each property, keyed by the row labels of `stats`.\n",
+        "    scaled_threshold : int or float\n",
+        "        Common value that each property's threshold is mapped to.\n",
+        "\n",
+        "    Returns\n",
+        "    -------\n",
+        "    pd.DataFrame\n",
+        "        DataFrame with scaled means and standard deviations for each property.\n",
+        "\n",
+        "    Raises\n",
+        "    ------\n",
+        "    KeyError\n",
+        "        If a property (row label) of `stats` has no entry in `thresholds`.\n",
+        "    \"\"\"\n",
+        "    # Every row label needs a threshold; stop at the first one that is missing\n",
+        "    for property_name in stats.index:\n",
+        "        if property_name in thresholds:\n",
+        "            continue\n",
+        "        raise KeyError(f\"Add property '{property_name}' to scaling variable.\")\n",
+        "\n",
+        "    def _scale_row(row):\n",
+        "        # row.name is the property label, which selects the matching threshold\n",
+        "        return row / thresholds[row.name] * scaled_threshold\n",
+        "\n",
+        "    return stats.apply(_scale_row, axis=1)"
+      ],
+      "metadata": {
+        "id": "KTnKx1drIKZM"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def _define_radial_axes_angles(n_axes):\n",
+        "    \"\"\"Return evenly spaced angles (radians) for `n_axes` radial axes, first angle repeated last.\"\"\"\n",
+        "    x_angles = []\n",
+        "    for axis_index in range(n_axes):\n",
+        "        x_angles.append(axis_index / float(n_axes) * 2 * math.pi)\n",
+        "    # Repeat the first angle at the end so that a plotted outline closes on itself\n",
+        "    return x_angles + x_angles[:1]"
+      ],
+      "metadata": {
+        "id": "oZj3q_HuIRNM"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def plot_radar(\n",
+        "    y,\n",
+        "    thresholds,\n",
+        "    scaled_threshold,\n",
+        "    properties_labels,\n",
+        "    y_max=None,\n",
+        "    output_path=None,\n",
+        "):\n",
+        "    \"\"\"\n",
+        "    Plot a radar chart based on the mean and standard deviation of a data set's properties.\n",
+        "\n",
+        "    Parameters\n",
+        "    ----------\n",
+        "    y : pd.DataFrame\n",
+        "        Dataframe with \"mean\" and \"std\" (columns) for each physicochemical property (rows).\n",
+        "    thresholds : dict of str: int\n",
+        "        Thresholds defined for each property.\n",
+        "    scaled_threshold : int or float\n",
+        "        Scaled thresholds across all properties.\n",
+        "    properties_labels : list of str\n",
+        "        List of property names to be used as labels in the plot.\n",
+        "    y_max : None or int or float\n",
+        "        Set maximum y value. If None, let matplotlib decide.\n",
+        "    output_path : None or pathlib.Path\n",
+        "        If not None, save plot to file.\n",
+        "    \"\"\"\n",
+        "\n",
+        "    # Define radial x-axes angles (the first angle is repeated at the end)\n",
+        "    x = _define_radial_axes_angles(len(y))\n",
+        "    # Scale y-axis values with respect to a defined threshold\n",
+        "    y = _scale_by_thresholds(y, thresholds, scaled_threshold)\n",
+        "    # Since our chart will be circular we append the first row to the end.\n",
+        "    # DataFrame.append was removed in pandas 2.0, so pd.concat is used instead.\n",
+        "    y = pd.concat([y, y.iloc[[0]]])\n",
+        "\n",
+        "    # Set figure and subplot axis\n",
+        "    plt.figure(figsize=(6, 6))\n",
+        "    ax = plt.subplot(111, polar=True)\n",
+        "\n",
+        "    # Plot data; keep the created artists so the legend can reference them explicitly.\n",
+        "    # The threshold polygon gets one point per angle instead of a hard-coded count.\n",
+        "    (threshold_area,) = ax.fill(x, [scaled_threshold] * len(x), \"cornflowerblue\", alpha=0.2)\n",
+        "    (mean_line,) = ax.plot(x, y[\"mean\"], \"b\", lw=3, ls=\"-\")\n",
+        "    (upper_line,) = ax.plot(x, y[\"mean\"] + y[\"std\"], \"orange\", lw=2, ls=\"--\")\n",
+        "    (lower_line,) = ax.plot(x, y[\"mean\"] - y[\"std\"], \"orange\", lw=2, ls=\"-.\")\n",
+        "\n",
+        "    # From here on, we only do plot cosmetics\n",
+        "    # Set 0° to 12 o'clock\n",
+        "    ax.set_theta_offset(math.pi / 2)\n",
+        "    # Set clockwise rotation\n",
+        "    ax.set_theta_direction(-1)\n",
+        "\n",
+        "    # Set y-labels next to 180° radius axis\n",
+        "    ax.set_rlabel_position(180)\n",
+        "    # Set number of radial axes' ticks and remove labels\n",
+        "    plt.xticks(x, [])\n",
+        "    # Get maximal y-ticks value\n",
+        "    if not y_max:\n",
+        "        y_max = int(ax.get_yticks()[-1])\n",
+        "    # Set axes limits\n",
+        "    plt.ylim(0, y_max)\n",
+        "    # Set number and labels of y axis ticks; range() needs an integer bound, so a\n",
+        "    # float y_max is rounded up. Only the tick at the scaled threshold is labelled.\n",
+        "    tick_values = range(1, math.ceil(y_max))\n",
+        "    plt.yticks(\n",
+        "        tick_values,\n",
+        "        [f\"{scaled_threshold:g}\" if i == scaled_threshold else \"\" for i in tick_values],\n",
+        "        fontsize=16,\n",
+        "    )\n",
+        "\n",
+        "    # Draw the property labels outside the plot to make sure they fit properly\n",
+        "    # Note that we use [:-1] to exclude the last element which equals the first element (not needed here)\n",
+        "    for angle, label in zip(x[:-1], properties_labels):\n",
+        "        if angle == 0:\n",
+        "            ha = \"center\"\n",
+        "        elif 0 < angle < math.pi:\n",
+        "            ha = \"left\"\n",
+        "        elif angle == math.pi:\n",
+        "            ha = \"center\"\n",
+        "        else:\n",
+        "            ha = \"right\"\n",
+        "        ax.text(\n",
+        "            x=angle,\n",
+        "            y=y_max + 1,\n",
+        "            s=label,\n",
+        "            size=16,\n",
+        "            horizontalalignment=ha,\n",
+        "            verticalalignment=\"center\",\n",
+        "        )\n",
+        "\n",
+        "    # Add legend relative to top-left plot. Handles are passed explicitly so that each\n",
+        "    # label is bound to its artist regardless of the order matplotlib collects them in.\n",
+        "    handles = [mean_line, upper_line, lower_line, threshold_area]\n",
+        "    labels = (\"mean\", \"mean + std\", \"mean - std\", \"rule of five area\")\n",
+        "    ax.legend(handles, labels, loc=(1.1, 0.7), labelspacing=0.3, fontsize=16)\n",
+        "\n",
+        "    # Save plot - use bbox_inches to include text boxes\n",
+        "    if output_path:\n",
+        "        plt.savefig(output_path, dpi=300, bbox_inches=\"tight\", transparent=True)\n",
+        "\n",
+        "    plt.show()"
+      ],
+      "metadata": {
+        "id": "t6UGc-35Iadf"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Per-property cut-offs used to scale each property onto a common axis\n",
+        "thresholds = dict(molecular_weight=500, n_hba=10, n_hbd=5, logp=5)\n",
+        "# Common value that every threshold is mapped to on the radar chart\n",
+        "scaled_threshold = 5\n",
+        "# Axis labels, listed in the same order as the property rows of the statistics tables\n",
+        "properties_labels = [\"Molecular weight (Da) / 100\", \"# HBA / 2\", \"# HBD\", \"LogP\"]\n",
+        "# Upper limit of the radial axis\n",
+        "y_max = 8"
+      ],
+      "metadata": {
+        "id": "2VD_HKI4IfKA"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Radar chart for the compounds that fulfill the Ro5\n",
+        "plot_radar(\n",
+        "    y=molecules_ro5_fulfilled_stats,\n",
+        "    thresholds=thresholds,\n",
+        "    scaled_threshold=scaled_threshold,\n",
+        "    properties_labels=properties_labels,\n",
+        "    y_max=y_max,\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "S5xKT3cRIh6u"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Radar chart for the compounds that violate the Ro5\n",
+        "plot_radar(\n",
+        "    y=molecules_ro5_violated_stats,\n",
+        "    thresholds=thresholds,\n",
+        "    scaled_threshold=scaled_threshold,\n",
+        "    properties_labels=properties_labels,\n",
+        "    y_max=y_max,\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "0Z_ToLjpI7Vs"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file