# %%
import os
# NOTE: relative chdir is not idempotent -- re-running this cell moves the
# working directory up one more level. Run it once per kernel.
os.chdir('../')

import DeepPurpose.CompoundPred as property_pred
from DeepPurpose.utils import *
from DeepPurpose.dataset import *

from sklearn.metrics import mean_squared_error, roc_auc_score, average_precision_score, f1_score


import numpy as np
# Explicit import: `pd` was previously only available through the star imports above.
import pandas as pd

from ax.plot.contour import interact_contour, plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting
# NOTE: the `train` imported here is shadowed by the processed training set below.
from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate, CNN


import warnings
warnings.filterwarnings("ignore")

init_notebook_plotting()

# %%
fold_n = 1
balanced = True
SEED = 1  # single seed reused for resampling, shuffling and data_process

# Raw per-fold frames keep their own names so that `train` / `test` are not
# rebound from a DataFrame to a processed dataset within the same cell.
fold_dir = './aicures_data/train_cv/fold_' + str(fold_n)
train_df = pd.read_csv(fold_dir + '/train.csv')
dev_df = pd.read_csv(fold_dir + '/dev.csv')
test_df = pd.read_csv(fold_dir + '/test.csv')
dev = dev_df  # backward-compatible alias: the original cell left `dev` bound to this frame

if balanced:
    # oversample balanced training: draw positives with replacement until they
    # match the number of negatives, then shuffle. Both sampling steps are
    # seeded so the resulting training set is reproducible.
    positives = train_df[train_df.activity == 1]
    negatives = train_df[train_df.activity == 0]
    oversampled = positives.sample(n = len(negatives), replace = True, random_state = SEED)
    train_df = (
        pd.concat([oversampled, negatives])
        .sample(frac = 1, random_state = SEED)
        .reset_index(drop = True)
    )

X_train = train_df.smiles.values
y_train = train_df.activity.values
X_dev = dev_df.smiles.values
y_dev = dev_df.activity.values
X_test = test_df.smiles.values
y_test = test_df.activity.values

drug_encoding = 'Morgan'
# The folds are already split on disk, hence split_method='no_split' for each.
train = data_process(X_drug = X_train, y = y_train,
                     drug_encoding = drug_encoding,
                     split_method = 'no_split',
                     random_seed = SEED)

val = data_process(X_drug = X_dev, y = y_dev,
                   drug_encoding = drug_encoding,
                   split_method = 'no_split',
                   random_seed = SEED)

test = data_process(X_drug = X_test, y = y_test,
                    drug_encoding = drug_encoding,
                    split_method = 'no_split',
                    random_seed = SEED)

# %%
def run_Morgan(parameterization):
    """Train one model for a hyperparameter trial and return its validation score.

    Reads the module-level `drug_encoding`, `train`, `val` and `test` objects
    prepared in the previous cell.

    Parameters
    ----------
    parameterization : dict
        Must provide the keys 'cls_hidden_dims', 'LR', 'batch_size' and 'decay'.
        'cls_hidden_dims' is a single width and is wrapped into a one-element list.

    Returns
    -------
    float
        Average precision of the model's predictions on the validation split.
    """
    config = generate_config(drug_encoding = drug_encoding,
                             cls_hidden_dims = [parameterization['cls_hidden_dims']],
                             train_epoch = 10,
                             LR = parameterization['LR'],
                             batch_size = parameterization['batch_size'],
                             decay = parameterization['decay']
                            )

    model = property_pred.model_initialize(**config)
    model.train(train, val, test, verbose = False)

    # Score on the validation split, not the test split: this value drives the
    # hyperparameter search, and selecting on test data would leak it into
    # model selection and inflate the final reported test performance.
    scores = model.predict(val, verbose = False)
    return average_precision_score(val.Label.values, scores)

# %%
# %%
# easy tuning

# %%
# Search over learning rate, weight decay, hidden width and batch size,
# scoring each trial with run_Morgan.
#
# NOTE: the objective is registered under the name 'accuracy', but run_Morgan
# returns an average-precision score. The name is left unchanged because it is
# the key of the returned `values` dict and may be looked up by later cells.
best_parameters, values, experiment, model = optimize(
    parameters=[
        # The learning rate spans three orders of magnitude, so sample it on a
        # log scale; a linear scale puts ~99% of the draws above 1e-5.
        {"name": "LR", "type": "range", "bounds": [1e-6, 1e-3], "log_scale": True},
        # The lower bound is 0.0, so this range cannot be log-scaled.
        {"name": "decay", "type": "range", "bounds": [0.0, 0.2], "log_scale": False},
        {"name": "cls_hidden_dims", "type": "choice", "values": [32, 64, 128, 256, 512]},
        {"name": "batch_size", "type": "choice", "values": [64, 128, 256]}
    ],
    evaluation_function=run_Morgan,
    objective_name='accuracy',
    minimize=False,    # higher score is better (explicit; matches the default)
    total_trials=20,   # explicit; matches the default used by the original run
    random_seed=1,     # makes the generated trial parameters reproducible
)

# %%
best_parameters

# %%
means, covariances = values
means, covariances
56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395 ] }, { "fill": "tonexty", "fillcolor": "rgba(128,177,211,0.3)", "legendgroup": "mean", "line": { "color": "rgba(128,177,211,1)" }, "mode": "lines", "name": "mean", "type": "scatter", "x": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ], "y": [ 54.896028666621696, 54.896028666621696, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395 ] }, { "fill": "tonexty", "fillcolor": "rgba(128,177,211,0.3)", "hoverinfo": "none", "legendgroup": "", "line": { "width": 0 }, "mode": "lines", "showlegend": false, "type": "scatter", "x": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ], "y": [ 54.896028666621696, 54.896028666621696, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395 ] } ], "layout": { "showlegend": true, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", 
"startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 
0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 
0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": 
"white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "Model performance vs. # of iterations" }, "xaxis": { "title": { "text": "Iteration" } }, "yaxis": { "title": { "text": "Classification Accuracy, %" } } } }, "text/html": [ "