{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 05-10 14:48:43] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.\n" ] }, { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import os\n", "os.chdir('../')\n", "\n", "import DeepPurpose.CompoundPred as property_pred\n", "from DeepPurpose.utils import *\n", "from DeepPurpose.dataset import *\n", "\n", "from sklearn.metrics import mean_squared_error, roc_auc_score, average_precision_score, f1_score\n", "\n", "\n", "import numpy as np\n", "\n", "from ax.plot.contour import interact_contour, plot_contour\n", "from ax.plot.trace import optimization_trace_single_method\n", "from ax.service.managed_loop import optimize\n", "from ax.utils.notebook.plotting import render, init_notebook_plotting\n", "from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate, CNN\n", "\n", "\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "init_notebook_plotting()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Drug Property Prediction Mode...\n", "in total: 3328 drugs\n", "encoding drug...\n", "unique drugs: 1694\n", "drug encoding finished...\n", "do not do train/test split on the data for already splitted data\n", "Drug Property Prediction Mode...\n", "in total: 201 drugs\n", "encoding drug...\n", "unique drugs: 201\n", "drug encoding finished...\n", "do not do train/test split on the data for already splitted data\n", "Drug Property Prediction Mode...\n", "in total: 202 drugs\n", "encoding drug...\n", "unique drugs: 202\n", "drug encoding finished...\n", "do not do train/test split on the data for already splitted data\n" ] } ], "source": [ "fold_n = 
1\n", "balanced = True\n", "SEED = 1  # fixes the oversampling/shuffle below so every re-run sees the same training set\n", "train = pd.read_csv('./aicures_data/train_cv/fold_'+str(fold_n)+'/train.csv')\n", "dev = pd.read_csv('./aicures_data/train_cv/fold_'+str(fold_n)+'/dev.csv')\n", "test = pd.read_csv('./aicures_data/train_cv/fold_'+str(fold_n)+'/test.csv')\n", "\n", "if balanced:\n", "    # oversample actives (activity == 1) with replacement up to the number of inactives, then shuffle\n", "    train = pd.concat([train[train.activity == 1].sample(n = len(train[train.activity == 0]), replace=True, random_state=SEED), train[train.activity == 0]]).sample(frac = 1, random_state=SEED).reset_index(drop = True)\n", "\n", "X_train = train.smiles.values\n", "y_train = train.activity.values\n", "X_dev = dev.smiles.values\n", "y_dev = dev.activity.values\n", "X_test = test.smiles.values\n", "y_test = test.activity.values\n", "\n", "drug_encoding = 'Morgan'\n", "train = data_process(X_drug = X_train, y = y_train, \n", "                     drug_encoding = drug_encoding,\n", "                     split_method='no_split', \n", "                     random_seed = SEED)\n", "\n", "val = data_process(X_drug = X_dev, y = y_dev, \n", "                   drug_encoding = drug_encoding,\n", "                   split_method='no_split', \n", "                   random_seed = SEED)\n", "\n", "test = data_process(X_drug = X_test, y = y_test, \n", "                    drug_encoding = drug_encoding,\n", "                    split_method='no_split', \n", "                    random_seed = SEED)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def run_Morgan(parameterization):\n", "    \"\"\"Ax objective: train a Morgan-fingerprint classifier and score it on the dev split.\n", "\n", "    `parameterization` is the dict proposed by Ax, with keys 'LR', 'decay',\n", "    'batch_size' and 'cls_hidden_dims'. Hyperparameters are selected on validation\n", "    data only, so the test split stays untouched for the final evaluation.\n", "    Returns the dev-set average precision (AUPRC).\n", "    \"\"\"\n", "    config = generate_config(drug_encoding = drug_encoding, \n", "                             cls_hidden_dims = [parameterization['cls_hidden_dims']], \n", "                             train_epoch = 10, \n", "                             LR = parameterization['LR'], \n", "                             batch_size = parameterization['batch_size'],\n", "                             decay = parameterization['decay']\n", "                            )\n", "\n", "    model = property_pred.model_initialize(**config)\n", "    # Neither train with nor score on `test` here: picking hyperparameters by test scores leaks the test set.\n", "    model.train(train, val, verbose = False)\n", "\n", "    scores = model.predict(val, verbose = False)\n", "    return average_precision_score(val.Label.values, scores)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", 
"execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# easy tuning " ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[INFO 05-10 14:48:44] ax.modelbridge.dispatch_utils: Using Sobol generation strategy.\n", "[INFO 05-10 14:48:44] ax.service.managed_loop: Started full optimization with 20 steps.\n", "[INFO 05-10 14:48:44] ax.service.managed_loop: Running optimization trial 1...\n", "[INFO 05-10 14:48:57] ax.service.managed_loop: Running optimization trial 2...\n", "[INFO 05-10 14:49:08] ax.service.managed_loop: Running optimization trial 3...\n", "[INFO 05-10 14:49:24] ax.service.managed_loop: Running optimization trial 4...\n", "[INFO 05-10 14:49:35] ax.service.managed_loop: Running optimization trial 5...\n", "[INFO 05-10 14:49:47] ax.service.managed_loop: Running optimization trial 6...\n", "[INFO 05-10 14:50:01] ax.service.managed_loop: Running optimization trial 7...\n", "[INFO 05-10 14:50:20] ax.service.managed_loop: Running optimization trial 8...\n", "[INFO 05-10 14:50:32] ax.service.managed_loop: Running optimization trial 9...\n", "[INFO 05-10 14:50:46] ax.service.managed_loop: Running optimization trial 10...\n", "[INFO 05-10 14:51:00] ax.service.managed_loop: Running optimization trial 11...\n", "[INFO 05-10 14:51:19] ax.service.managed_loop: Running optimization trial 12...\n", "[INFO 05-10 14:51:29] ax.service.managed_loop: Running optimization trial 13...\n", "[INFO 05-10 14:51:44] ax.service.managed_loop: Running optimization trial 14...\n", "[INFO 05-10 14:52:00] ax.service.managed_loop: Running optimization trial 15...\n", "[INFO 05-10 14:52:14] ax.service.managed_loop: Running optimization trial 16...\n", "[INFO 05-10 14:52:32] ax.service.managed_loop: Running optimization trial 17...\n", "[INFO 05-10 14:52:46] ax.service.managed_loop: Running optimization trial 18...\n", "[INFO 05-10 14:52:57] ax.service.managed_loop: Running 
optimization trial 19...\n", "[INFO 05-10 14:53:13] ax.service.managed_loop: Running optimization trial 20...\n" ] } ], "source": [ "# LR spans three orders of magnitude (1e-6 to 1e-3), so it is searched on a log scale.\n", "# run_Morgan returns average precision (AUPRC), so the objective is named after that metric.\n", "best_parameters, values, experiment, model = optimize(\n", "    parameters=[\n", "        {\"name\": \"LR\", \"type\": \"range\", \"bounds\": [1e-6, 1e-3], \"log_scale\": True},\n", "        {\"name\": \"decay\", \"type\": \"range\", \"bounds\": [0.0, 0.2], \"log_scale\": False},\n", "        {\"name\": \"cls_hidden_dims\", \"type\": \"choice\", \"values\": [32, 64, 128, 256, 512]},\n", "        {\"name\": \"batch_size\", \"type\": \"choice\", \"values\": [64, 128, 256]}\n", "    ],\n", "    evaluation_function=run_Morgan,\n", "    objective_name='average_precision',\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'LR': 0.00017391057281102986,\n", " 'decay': 0.13253133054822683,\n", " 'cls_hidden_dims': 256,\n", " 'batch_size': 128}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_parameters" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "({'accuracy': 0.567285785786404}, {'accuracy': {'accuracy': 0.0}})" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "means, covariances = values\n", "means, covariances" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "linkText": "Export to plot.ly", "plotlyServerURL": "https://plot.ly", "showLink": false }, "data": [ { "hoverinfo": "none", "legendgroup": "", "line": { "width": 0 }, "mode": "lines", "showlegend": false, "type": "scatter", "x": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ], "y": [ 54.896028666621696, 54.896028666621696, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.728578578640395, 56.728578578640395, 
56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395 ] }, { "fill": "tonexty", "fillcolor": "rgba(128,177,211,0.3)", "legendgroup": "mean", "line": { "color": "rgba(128,177,211,1)" }, "mode": "lines", "name": "mean", "type": "scatter", "x": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ], "y": [ 54.896028666621696, 54.896028666621696, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395 ] }, { "fill": "tonexty", "fillcolor": "rgba(128,177,211,0.3)", "hoverinfo": "none", "legendgroup": "", "line": { "width": 0 }, "mode": "lines", "showlegend": false, "type": "scatter", "x": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ], "y": [ 54.896028666621696, 54.896028666621696, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.034234323550145, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395, 56.728578578640395 ] } ], "layout": { "showlegend": true, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", 
"startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 
0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 
0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": 
"white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "Model performance vs. # of iterations" }, "xaxis": { "title": { "text": "Iteration" } }, "yaxis": { "title": { "text": "Classification Accuracy, %" } } } }, "text/html": [ "
\n", " \n", " \n", "
\n", " \n", "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# one objective value (AUPRC, as a percentage) per trial\n", "trial_objectives = np.array([[trial.objective_mean*100 for trial in experiment.trials.values()]])\n", "best_objective_plot = optimization_trace_single_method(\n", "    y=np.maximum.accumulate(trial_objectives, axis=1),  # best value seen so far\n", "    title=\"Model performance vs. # of iterations\",\n", "    ylabel=\"Best average precision (AUPRC), %\",\n", ")\n", "render(best_objective_plot)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# for more customized optimization and plotting" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import numpy as np\n", "import argparse\n", "import datetime\n", "from ax import ParameterType, ChoiceParameter, RangeParameter, FixedParameter, SearchSpace, SimpleExperiment, modelbridge, models\n", "from ax.plot.contour import interact_contour, plot_contour\n", "from ax.plot.diagnostic import interact_cross_validation\n", "from ax.plot.scatter import interact_fitted, plot_objective_vs_constraints\n", "from ax.plot.slice import plot_slice\n", "from ax.modelbridge.cross_validation import cross_validate\n", "from ax.plot.trace import optimization_trace_single_method\n", "from plotly.offline import plot" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Running Sobol initialization trials...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 1/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 2/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 3/16...\n", 
"========================================\n", "\n", "\n", "Running GP+EI optimization trial 4/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 5/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 6/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 7/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 8/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 9/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 10/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 11/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 12/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 13/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 14/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 15/16...\n", "========================================\n", "\n", "\n", "Running GP+EI optimization trial 16/16...\n", "========================================\n", "\n" ] } ], "source": [ "opt_trials = 16\n", "init_trials = 5\n", "dset = 'Morgan'\n", "objective_name = 'average_precision'  # run_Morgan returns average precision (AUPRC)\n", "\n", "# Search space (LR spans three orders of magnitude, so it is searched on a log scale)\n", "search_space = SearchSpace(parameters=[\n", "    RangeParameter(\n", "        name='LR', parameter_type=ParameterType.FLOAT, \n", "        lower=1e-6, upper=1e-3, log_scale=True),\n", "    RangeParameter(\n", "        name='decay', parameter_type=ParameterType.FLOAT, \n", "        lower=0, upper=0.2, log_scale=False),\n", "    ChoiceParameter(\n", "        name='batch_size', parameter_type=ParameterType.INT, \n", "        values=[64, 128, 256]), \n", "    ChoiceParameter(\n", "        name='cls_hidden_dims', parameter_type=ParameterType.INT, \n",
"        values=[64, 128, 256, 512]), \n", "])\n", "\n", "# Create Experiment\n", "exp = SimpleExperiment(\n", "    name = 'Morgan',\n", "    search_space = search_space,\n", "    evaluation_function = run_Morgan,\n", "    objective_name = objective_name,\n", ")\n", "\n", "# Sobol initialization, then one GP+EI proposal per iteration (GP refit on all data evaluated so far)\n", "sobol = modelbridge.get_sobol(search_space=exp.search_space)\n", "print(f\"\\nRunning Sobol initialization trials...\\n{'='*40}\\n\")\n", "for _ in range(init_trials):\n", "    exp.new_trial(generator_run=sobol.gen(1))\n", "\n", "for i in range(opt_trials):\n", "    print(f\"\\nRunning GP+EI optimization trial {i+1}/{opt_trials}...\\n{'='*40}\\n\")\n", "    gpei = modelbridge.get_GPEI(experiment=exp, data=exp.eval())\n", "    exp.new_trial(generator_run=gpei.gen(1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# each plot will be opened in a html tab " ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best arm:\n", " Arm(name='16_0', parameters={'LR': 0.00043294545493323936, 'decay': 0.08204705377971983, 'batch_size': 64, 'cls_hidden_dims': 256})\n" ] }, { "data": { "text/plain": [ "'Ax_output/Morgan/0510-145936/cv_plot.html'" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output_dir = os.path.join('Ax_output', dset, datetime.datetime.now().strftime('%m%d-%H%M%S'))\n", "os.makedirs(output_dir)\n", "\n", "# Evaluate any pending trial: the loop above creates its last trial without evaluating it\n", "data = exp.eval()\n", "# Refit the GP on every evaluated trial so the plots below reflect all of the data\n", "gpei = modelbridge.get_GPEI(experiment=exp, data=data)\n", "\n", "# Save all experiment parameters \n", "df = data.df\n", "df.to_csv(os.path.join(output_dir, 'exp_eval.csv'), index=False)\n", "\n", "# Save best parameter\n", "best_arm_name = df.arm_name[df['mean'] == df['mean'].max()].values[0]\n", "exp_arm = {k:v.parameters for k, v in exp.arms_by_name.items()}\n", "exp_arm['best'] = best_arm_name\n", "print('Best arm:\\n', str(exp.arms_by_name[best_arm_name]))\n", "with open(os.path.join(output_dir, 'exp_arm.json'), 'w') as f: \n", "    json.dump(exp_arm, f)\n", "\n", "# Contour Plot\n", "os.makedirs(os.path.join(output_dir, 'contour_plot'))\n", "for metric in [objective_name]:\n", "    contour_plot = interact_contour(model=gpei, metric_name=metric)\n", "    plot(contour_plot.data, filename=os.path.join(output_dir, 'contour_plot', '{}.html'.format(metric)))\n", "\n", "# Slice Plot\n", "# show the metric outcome as a function of one parameter while fixing the others\n", "os.makedirs(os.path.join(output_dir, 'slice_plot'))\n", "for param in [\"LR\", \"decay\"]:\n", "    slice_plot = plot_slice(gpei, param, objective_name)\n", "    plot(slice_plot.data, filename=os.path.join(output_dir, 'slice_plot', '{}.html'.format(param)))\n", "\n", "# Tile Plot\n", "# the effect of each arm\n", "tile_plot = interact_fitted(gpei, rel=False)\n", "plot(tile_plot.data, filename=os.path.join(output_dir, 'tile_plot.html'))\n", "\n", "# Cross Validation plot\n", "# splits the model's train data into train/test folds and makes out-of-sample predictions on the test folds.\n", "cv_results = cross_validate(gpei)\n", "cv_plot = interact_cross_validation(cv_results)\n", "plot(cv_plot.data, filename=os.path.join(output_dir, 'cv_plot.html'))\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# use the selected hyperparameters to retrain and evaluate on the fold's test split." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def evaluate_Morgan_fold(fold_n, balanced, parameterization, seed = 1):\n", "    \"\"\"Retrain a Morgan-fingerprint classifier on one CV fold and score it on that fold's test split.\n", "\n", "    Named differently from the Ax objective `run_Morgan` so that re-running the\n", "    optimization cells after this one still calls the objective, not this function.\n", "\n", "    Args:\n", "        fold_n: fold index; CSVs are read from ./aicures_data/train_cv/fold_{fold_n}/.\n", "        balanced: if True, oversample actives (activity == 1) with replacement up to the\n", "            number of inactives before training.\n", "        parameterization: dict with 'LR', 'decay', 'batch_size' and 'cls_hidden_dims'\n", "            (e.g. best_parameters from Ax).\n", "        seed: random_state for the oversampling/shuffle and random_seed for data_process.\n", "\n", "    Returns:\n", "        (test AUROC, test AUPRC, predicted test scores, true test labels)\n", "    \"\"\"\n", "    data_dir = './aicures_data/train_cv/fold_'+str(fold_n)\n", "    train = pd.read_csv(data_dir+'/train.csv')\n", "    dev = pd.read_csv(data_dir+'/dev.csv')\n", "    test = pd.read_csv(data_dir+'/test.csv')\n", "\n", "    if balanced:\n", "        # oversample actives with replacement up to the number of inactives, then shuffle\n", "        train = pd.concat([train[train.activity == 1].sample(n = len(train[train.activity == 0]), replace=True, random_state=seed), train[train.activity == 0]]).sample(frac = 1, random_state=seed).reset_index(drop = True)\n", "\n", "    drug_encoding = 'Morgan'\n", "    train = data_process(X_drug = train.smiles.values, y = train.activity.values, \n", "                         drug_encoding = drug_encoding,\n", "                         split_method='no_split', \n", "                         random_seed = seed)\n", "\n", "    val = data_process(X_drug = dev.smiles.values, y = dev.activity.values, \n", "                       drug_encoding = drug_encoding,\n", "                       split_method='no_split', \n", "                       random_seed = seed)\n", "\n", "    test = data_process(X_drug = test.smiles.values, y = test.activity.values, \n", "                        drug_encoding = drug_encoding,\n", "                        split_method='no_split', \n", "                        random_seed = seed)\n", "\n", "    config = generate_config(drug_encoding = drug_encoding, \n", "                             cls_hidden_dims = [parameterization['cls_hidden_dims']], \n", "                             train_epoch = 10, \n", "                             LR = parameterization['LR'], \n", "                             batch_size = parameterization['batch_size'],\n", "                             decay = parameterization['decay']\n", "                            )\n", "\n", "    model = property_pred.model_initialize(**config)\n", "    model.train(train, val, test)\n", "\n", "    scores = model.predict(test)\n", "    labels = test.Label.values\n", "    return roc_auc_score(labels, scores), average_precision_score(labels, scores), scores, labels" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Drug Property Prediction Mode...\n", "in total: 3328 drugs\n", "encoding drug...\n", "unique drugs: 1694\n", "drug encoding finished...\n", "do not do train/test split on the data for already splitted data\n", "Drug Property Prediction Mode...\n", "in total: 201 drugs\n", "encoding drug...\n", "unique drugs: 201\n", "drug encoding finished...\n", "do not do train/test split on the data for already splitted data\n", "Drug Property Prediction Mode...\n", "in total: 202 drugs\n", "encoding drug...\n", "unique drugs: 202\n", "drug encoding finished...\n", "do not do train/test split on the data for already splitted data\n", "Let's use CPU/s!\n", "--- Data Preparation ---\n", "--- Go for Training ---\n", "Training at Epoch 1 iteration 0 with loss 0.69255. Total time 0.0 hours\n", "Validation at Epoch 1 , AUROC: 0.79187 , AUPRC: 0.53684 , F1: 0.36363\n", "Training at Epoch 2 iteration 0 with loss 0.55604. Total time 0.00027 hours\n", "Validation at Epoch 2 , AUROC: 0.78553 , AUPRC: 0.53354 , F1: 0.57142\n", "Training at Epoch 3 iteration 0 with loss 0.02210. Total time 0.00055 hours\n", "Validation at Epoch 3 , AUROC: 0.77030 , AUPRC: 0.31726 , F1: 0.0\n", "Training at Epoch 4 iteration 0 with loss 0.00274. Total time 0.00083 hours\n", "Validation at Epoch 4 , AUROC: 0.78426 , AUPRC: 0.45020 , F1: 0.4\n", "Training at Epoch 5 iteration 0 with loss 0.01539. Total time 0.00111 hours\n", "Validation at Epoch 5 , AUROC: 0.78807 , AUPRC: 0.53458 , F1: 0.57142\n", "Training at Epoch 6 iteration 0 with loss 0.00025. Total time 0.00166 hours\n", "Validation at Epoch 6 , AUROC: 0.76903 , AUPRC: 0.31680 , F1: 0.0\n", "Training at Epoch 7 iteration 0 with loss 0.00206. Total time 0.00194 hours\n", "Validation at Epoch 7 , AUROC: 0.78807 , AUPRC: 0.53560 , F1: 0.57142\n", "Training at Epoch 8 iteration 0 with loss 0.00017. Total time 0.00222 hours\n", "Validation at Epoch 8 , AUROC: 0.78680 , AUPRC: 0.45226 , F1: 0.4\n", "Training at Epoch 9 iteration 0 with loss 0.00035. Total time 0.0025 hours\n", "Validation at Epoch 9 , AUROC: 0.78426 , AUPRC: 0.45020 , F1: 0.4\n", "Training at Epoch 10 iteration 0 with loss 0.00020. Total time 0.00305 hours\n", "Validation at Epoch 10 , AUROC: 0.79568 , AUPRC: 0.55345 , F1: 0.57142\n", "--- Go for Testing ---\n", "Testing AUROC: 0.6626139817629179 , AUPRC: 0.539919932901533 , F1: 0.6363636363636364\n", "--- Training Finished ---\n", "predicting...\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "roc, prc, scores, labels = evaluate_Morgan_fold(1, True, best_parameters)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# We thank Chih-Ying Deng and Ax tutorials for the scripts on BO." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 4 }