--- a
+++ b/tutorials/mimic_classifier.ipynb
@@ -0,0 +1,445 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.10"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python3610jvsc74a57bd01452ce364e145c8938a76e90050576b8a2a4d70ee75de50f3361ff243fa2a5f7",
+   "display_name": "Python 3.6.10 64-bit ('EHRKit': conda)"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "source": [
+    "# MIMIC ICD-9 coding\n",
+    "This notebook walks the user through training and testing an ICD-9 coding classifier, built specifically on and for the MIMIC-III dataset. The notebook uses the `mimic_icd9_coding` module; the example below trains a neural network classifier, although any other classifier may be used.\n",
+    "\n",
+    "The user must set the root directory of EHRKit, and optionally the data directories."
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ROOT_EHR_DIR = '<EHRKit Path>' # Set your root EHRKit directory here (with the '/' at the end)\n",
+    "import sys\n",
+    "import os\n",
+    "sys.path.append(os.path.dirname(ROOT_EHR_DIR))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set your MIMIC path here.\n",
+    "# Place all of the individual MIMIC csv files in MIMIC_PATH (with the '/' at the end); these are the all-caps csv files, such as NOTEEVENTS.csv.\n",
+    "# Leave OUTPUT_DATA_PATH empty; the processed data will be deposited there.\n",
+    "OUTPUT_DATA_PATH = ROOT_EHR_DIR + 'data/output_data/'\n",
+    "MIMIC_PATH = ROOT_EHR_DIR + 'data/mimic_data/'"
+   ]
+  },
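+  {
+   "source": [
+    "Before running the preparation step, it can help to sanity-check that the raw MIMIC files are in place. The cell below is a minimal, optional sketch: it assumes `NOTEEVENTS.csv` (named above) and `DIAGNOSES_ICD.csv` are among the tables you need, so adjust the list for your setup. It also creates the (initially empty) output folder if it does not exist yet."
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sanity check; the exact file list is an assumption for illustration.\n",
+    "for fname in ['NOTEEVENTS.csv', 'DIAGNOSES_ICD.csv']:\n",
+    "    path = os.path.join(MIMIC_PATH, fname)\n",
+    "    print('{}: {}'.format(path, 'found' if os.path.exists(path) else 'MISSING'))\n",
+    "os.makedirs(OUTPUT_DATA_PATH, exist_ok=True)  # ensure the (empty) output folder exists"
+   ]
+  },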
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stderr",
+     "text": [
+      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
+      "The text.latex.preview rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
+      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
+      "The mathtext.fallback_to_cm rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
+      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: Support for setting the 'mathtext.fallback_to_cm' rcParam is deprecated since 3.3 and will be removed two minor releases later; use 'mathtext.fallback : 'cm' instead.\n",
+      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
+      "The validate_bool_maybe_none function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
+      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
+      "The savefig.jpeg_quality rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
+      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
+      "The keymap.all_axes rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
+      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
+      "The animation.avconv_path rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
+      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
+      "The animation.avconv_args rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from mimic_icd9_coding.coding_pipeline import codingPipeline\n",
+    "from mimic_icd9_coding.utils.mimic_data_preparation import run_mimic_prep"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Run the MIMIC label preparation: examine the top n most common labels and create a dataset of labels and text.\n",
+    "run_mimic_prep(output_folder=OUTPUT_DATA_PATH, mimic_data_path=MIMIC_PATH)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Building basic tfidf pipeline\n",
+      "Iteration 1, loss = 39.26339982\n",
+      "Iteration 2, loss = 27.15691598\n",
+      "Iteration 3, loss = 25.43022223\n",
+      "Iteration 4, loss = 24.39200925\n",
+      "Iteration 5, loss = 23.53421875\n",
+      "Iteration 6, loss = 22.71761089\n",
+      "Iteration 7, loss = 21.98377343\n",
+      "Iteration 8, loss = 21.40524055\n",
+      "Iteration 9, loss = 20.91691075\n",
+      "Iteration 10, loss = 20.46976534\n",
+      "Iteration 11, loss = 20.05602722\n",
+      "Iteration 12, loss = 19.67969932\n",
+      "Iteration 13, loss = 19.32472493\n",
+      "Iteration 14, loss = 18.98996470\n",
+      "Iteration 15, loss = 18.69047703\n",
+      "Iteration 16, loss = 18.41708954\n",
+      "Iteration 17, loss = 18.16937539\n",
+      "Iteration 18, loss = 17.94176193\n",
+      "Iteration 19, loss = 17.72638875\n",
+      "Iteration 20, loss = 17.52546337\n",
+      "Iteration 21, loss = 17.33147929\n",
+      "Iteration 22, loss = 17.14397179\n",
+      "Iteration 23, loss = 16.97036583\n",
+      "Iteration 24, loss = 16.80744903\n",
+      "Iteration 25, loss = 16.65087996\n",
+      "Iteration 26, loss = 16.50350788\n",
+      "Iteration 27, loss = 16.36437565\n",
+      "Iteration 28, loss = 16.22898538\n",
+      "Iteration 29, loss = 16.10017515\n",
+      "Iteration 30, loss = 15.97871846\n",
+      "Iteration 31, loss = 15.86154579\n",
+      "Iteration 32, loss = 15.75070126\n",
+      "Iteration 33, loss = 15.64012650\n",
+      "Iteration 34, loss = 15.53565693\n",
+      "Iteration 35, loss = 15.43139960\n",
+      "Iteration 36, loss = 15.32890083\n",
+      "Iteration 37, loss = 15.23351048\n",
+      "Iteration 38, loss = 15.14100144\n",
+      "Iteration 39, loss = 15.05027343\n",
+      "Iteration 40, loss = 14.96409722\n",
+      "Iteration 41, loss = 14.88040182\n",
+      "Iteration 42, loss = 14.80062628\n",
+      "Iteration 43, loss = 14.72397634\n",
+      "Iteration 44, loss = 14.64785816\n",
+      "Iteration 45, loss = 14.57387839\n",
+      "Iteration 46, loss = 14.50342585\n",
+      "Iteration 47, loss = 14.43510663\n",
+      "Iteration 48, loss = 14.36782499\n",
+      "Iteration 49, loss = 14.30226749\n",
+      "Iteration 50, loss = 14.24259876\n",
+      "Iteration 51, loss = 14.17783637\n",
+      "Iteration 52, loss = 14.12287286\n",
+      "Iteration 53, loss = 14.05997085\n",
+      "Iteration 54, loss = 14.00637036\n",
+      "Iteration 55, loss = 13.94991378\n",
+      "Iteration 56, loss = 13.89237458\n",
+      "Iteration 57, loss = 13.83448240\n",
+      "Iteration 58, loss = 13.77879710\n",
+      "Iteration 59, loss = 
13.72299570\n", + "Iteration 60, loss = 13.66682543\n", + "Iteration 61, loss = 13.61414798\n", + "Iteration 62, loss = 13.56323821\n", + "Iteration 63, loss = 13.51129347\n", + "Iteration 64, loss = 13.46183629\n", + "Iteration 65, loss = 13.41255142\n", + "Iteration 66, loss = 13.36630654\n", + "Iteration 67, loss = 13.31819055\n", + "Iteration 68, loss = 13.27420876\n", + "Iteration 69, loss = 13.23153613\n", + "Iteration 70, loss = 13.18662444\n", + "Iteration 71, loss = 13.14404019\n", + "Iteration 72, loss = 13.10327418\n", + "Iteration 73, loss = 13.06489162\n", + "Iteration 74, loss = 13.02366019\n", + "Iteration 75, loss = 12.98480324\n", + "Iteration 76, loss = 12.94750608\n", + "Iteration 77, loss = 12.90973475\n", + "Iteration 78, loss = 12.86992567\n", + "Iteration 79, loss = 12.83378217\n", + "Iteration 80, loss = 12.80129023\n", + "Iteration 81, loss = 12.76208668\n", + "Iteration 82, loss = 12.72865387\n", + "Iteration 83, loss = 12.69202347\n", + "Iteration 84, loss = 12.65801739\n", + "Iteration 85, loss = 12.62539280\n", + "Iteration 86, loss = 12.59177266\n", + "Iteration 87, loss = 12.56355448\n", + "Iteration 88, loss = 12.52974378\n", + "Iteration 89, loss = 12.49803458\n", + "Iteration 90, loss = 12.47011190\n", + "Iteration 91, loss = 12.44098206\n", + "Iteration 92, loss = 12.41264979\n", + "Iteration 93, loss = 12.38165628\n", + "Iteration 94, loss = 12.35401116\n", + "Iteration 95, loss = 12.32752414\n", + "Iteration 96, loss = 12.29908133\n", + "Iteration 97, loss = 12.27293193\n", + "Iteration 98, loss = 12.24794822\n", + "Iteration 99, loss = 12.22328992\n", + "Iteration 100, loss = 12.19652259\n", + "/home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/sklearn/neural_network/_multilayer_perceptron.py:585: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (100) reached and the optimization hasn't converged yet.\n", + " % self.max_iter, ConvergenceWarning)\n", + "/home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in samples with no predicted labels. 
Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n", + "\n", + "Classification Report\n", + "======================================================\n", + "\n", + " precision recall f1-score support\n", + "\n", + " 0 0.66 0.45 0.54 485\n", + " 1 0.40 0.19 0.26 1046\n", + " 2 0.69 0.60 0.64 1826\n", + " 3 0.49 0.25 0.33 2191\n", + " 4 0.83 0.65 0.73 608\n", + " 5 0.54 0.32 0.40 1470\n", + " 6 0.55 0.37 0.44 1327\n", + " 7 0.56 0.25 0.34 337\n", + " 8 0.75 0.55 0.63 748\n", + " 9 0.81 0.75 0.78 385\n", + " 10 0.43 0.16 0.23 486\n", + " 11 0.67 0.56 0.61 337\n", + " 12 0.66 0.49 0.56 538\n", + " 13 0.73 0.59 0.65 542\n", + " 14 0.80 0.70 0.75 1571\n", + " 15 0.88 0.77 0.82 4184\n", + " 16 0.24 0.05 0.08 397\n", + " 17 0.73 0.64 0.68 4229\n", + " 18 0.70 0.51 0.59 679\n", + " 19 0.27 0.02 0.04 354\n", + " 20 0.60 0.53 0.56 4370\n", + " 21 0.65 0.37 0.47 786\n", + " 22 0.40 0.14 0.21 654\n", + " 23 0.50 0.29 0.37 269\n", + " 24 0.58 0.42 0.49 3737\n", + " 25 0.50 0.21 0.29 591\n", + " 26 0.59 0.31 0.41 1175\n", + " 27 0.48 0.11 0.18 485\n", + " 28 0.70 0.50 0.58 308\n", + " 29 0.40 0.06 0.11 458\n", + " 30 0.65 0.44 0.52 560\n", + " 31 0.73 0.44 0.55 416\n", + " 32 0.51 0.30 0.38 846\n", + " 33 0.62 0.39 0.48 603\n", + " 34 0.50 0.35 0.41 1432\n", + " 35 0.71 0.48 0.57 675\n", + " 36 0.50 0.35 0.41 386\n", + " 37 0.42 0.13 0.20 342\n", + " 38 0.64 0.41 0.50 460\n", + " 39 0.57 0.34 0.43 955\n", + " 40 0.52 0.31 0.39 537\n", + " 41 0.59 0.29 0.39 399\n", + " 42 0.71 0.70 0.71 6165\n", + " 43 0.69 0.54 0.61 1873\n", + " 44 0.79 0.69 0.74 1606\n", + " 45 0.43 0.22 0.29 547\n", + " 46 0.34 0.06 0.10 333\n", + " 47 0.83 0.74 0.78 4202\n", + " 48 0.72 0.55 0.62 319\n", + " 49 0.47 0.23 0.31 769\n", + " 50 0.61 0.37 0.46 1620\n", + " 51 0.62 0.37 0.46 631\n", + " 52 0.64 0.30 0.41 431\n", + " 53 0.85 0.76 0.80 5036\n", + " 54 0.76 0.70 0.73 3901\n", + " 55 0.66 0.37 0.48 332\n", + " 56 0.72 0.51 0.59 460\n", + " 57 0.40 0.22 0.28 400\n", + " 58 0.57 0.28 0.38 515\n", + " 59 0.79 0.60 0.68 479\n", + " 60 0.35 0.09 0.14 516\n", + " 61 0.54 0.39 0.45 496\n", + " 62 0.69 0.62 0.65 304\n", + " 63 0.36 0.19 0.25 1459\n", + " 64 0.44 0.27 0.33 616\n", + " 65 0.66 0.41 0.51 399\n", + " 66 0.75 0.56 0.64 1044\n", + " 67 0.58 0.52 0.55 1046\n", + " 68 0.56 0.31 0.40 976\n", + " 69 0.72 0.60 0.65 3891\n", + " 70 0.64 0.44 0.52 333\n", + " 71 0.71 0.53 0.61 2201\n", + " 72 0.54 0.31 0.40 304\n", + " 73 0.51 0.22 0.31 334\n", + " 74 0.58 0.39 0.47 520\n", + " 75 0.67 0.34 0.45 411\n", + " 76 0.31 0.04 0.07 388\n", + " 77 0.51 0.34 0.41 296\n", + " 78 0.48 0.16 0.24 314\n", + " 79 0.83 0.63 0.72 886\n", + " 80 0.68 0.61 0.65 495\n", + " 81 0.66 0.46 0.54 326\n", + " 82 0.70 0.53 0.60 329\n", + " 83 0.80 0.50 0.61 473\n", + " 84 0.47 0.22 0.30 728\n", + " 85 0.75 0.66 0.70 3314\n", + " 86 0.75 0.60 0.66 1918\n", + " 87 0.21 0.01 0.03 296\n", + " 88 0.68 0.51 0.58 2144\n", + " 89 0.66 0.46 0.55 600\n", + " 90 0.58 0.37 0.45 481\n", + " 91 0.62 0.42 0.50 887\n", + " 92 0.44 0.14 0.21 444\n", + " 93 0.32 0.05 0.09 426\n", + " 94 0.33 0.01 0.02 323\n", + " 95 0.60 0.41 0.49 814\n", + " 96 0.96 0.95 0.96 898\n", + " 97 0.79 0.79 0.79 685\n", + " 98 0.91 0.88 0.89 759\n", + " 99 0.62 0.56 0.59 440\n", + " 100 0.49 0.20 0.29 1563\n", + " 101 0.42 0.11 0.17 428\n", + " 102 0.61 0.44 0.51 1498\n", + " 103 0.59 0.13 0.22 412\n", + " 104 0.35 0.08 0.14 733\n", + " 105 0.52 0.27 0.35 566\n", + " 106 0.59 0.35 0.44 607\n", + " 107 0.35 0.08 0.12 1137\n", + " 108 
0.36      0.05      0.09       505\n",
+      "         109       0.78      0.56      0.65       175\n",
+      "         110       0.67      0.53      0.59       279\n",
+      "         111       0.78      0.61      0.68       299\n",
+      "         112       0.52      0.47      0.49       323\n",
+      "         113       0.54      0.27      0.36       256\n",
+      "         114       0.66      0.58      0.62      1654\n",
+      "         115       0.53      0.29      0.38      1247\n",
+      "         116       0.43      0.25      0.32      1473\n",
+      "         117       0.52      0.33      0.41      1450\n",
+      "         118       0.82      0.69      0.75       393\n",
+      "         119       0.11      0.01      0.01       596\n",
+      "         120       0.46      0.25      0.33      1876\n",
+      "         121       0.75      0.64      0.69       853\n",
+      "         122       0.40      0.13      0.19      1010\n",
+      "         123       0.43      0.07      0.12       306\n",
+      "         124       0.77      0.78      0.78       632\n",
+      "         125       0.58      0.38      0.46      1941\n",
+      "         126       0.37      0.13      0.19      1115\n",
+      "         127       0.34      0.08      0.13      1359\n",
+      "         128       0.79      0.82      0.81       774\n",
+      "         129       0.89      0.91      0.90       699\n",
+      "         130       0.88      0.85      0.86       305\n",
+      "         131       0.43      0.16      0.23       353\n",
+      "         132       0.43      0.16      0.24       659\n",
+      "         133       0.43      0.25      0.32       413\n",
+      "         134       0.60      0.42      0.49      2650\n",
+      "         135       0.40      0.15      0.22       640\n",
+      "         136       0.51      0.41      0.46       224\n",
+      "         137       0.55      0.36      0.44      2125\n",
+      "         138       0.24      0.09      0.13       309\n",
+      "\n",
+      "   micro avg       0.67      0.47      0.55    141154\n",
+      "   macro avg       0.58      0.39      0.45    141154\n",
+      "weighted avg       0.63      0.47      0.53    141154\n",
+      " samples avg       0.67      0.49      0.54    141154\n",
+      "\n",
+      "Pipeline complete\n",
+      "/home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.\n",
+      "  _warn_prf(average, modifier, msg_start, len(result))\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Create and train the MIMIC data preprocessing and model; the default is a random forest model,\n",
+    "# but any model can be passed in (scikit-learn, or another model that fits and predicts in the same manner).\n",
+    "# Note that the resulting classification report gives the metrics for each of the 139 ICD-9 codes (labels 0-138) that this model predicts.\n",
+    "print(\"Building basic tfidf pipeline\")\n",
+    "from sklearn.neural_network import MLPClassifier\n",
+    "clf = MLPClassifier(hidden_layer_sizes=(100,), max_iter=100, verbose=True)\n",
+    "# max_iter=100 gives better results; for a quick first run, drop it to 10\n",
+    "my_mimic_pipeline = codingPipeline(verbose=True, model=clf, data_path=OUTPUT_DATA_PATH)\n",
+    "print(\"Pipeline complete\")\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "AUROC is 0.69\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Let's check out the AUROC\n",
+    "auroc = my_mimic_pipeline.auroc\n",
+    "print(\"AUROC is {:.2f}\".format(auroc))"
+   ]
+  },
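+  {
+   "source": [
+    "As noted above, `codingPipeline` accepts any model that fits and predicts in the scikit-learn manner (the default is a random forest). As a minimal sketch, here is how the same pipeline could be built with a random forest instead of the neural network; the hyperparameters below are illustrative assumptions, not tuned values."
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal sketch of swapping in a different classifier.\n",
+    "# Any model with sklearn-style fit/predict works; these hyperparameters are illustrative, not tuned.\n",
+    "from sklearn.ensemble import RandomForestClassifier\n",
+    "rf_clf = RandomForestClassifier(n_estimators=100, n_jobs=-1)\n",
+    "rf_pipeline = codingPipeline(verbose=True, model=rf_clf, data_path=OUTPUT_DATA_PATH)"
+   ]
+  },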
+  {
+   "source": [
+    "## Let's test out the model on a specific note!"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Here we load the data into the pipeline. load_data() simply stores the dataframe on the pipeline object;\n",
+    "# the data is not kept around automatically because that uses more memory.\n",
+    "my_mimic_pipeline.load_data()\n",
+    "df = my_mimic_pipeline.data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Predicted ICD9 codes: [('424', '441', '785')]\nTrue ICD9 codes: ['424', '441', '785']\n"
+     ]
+    }
+   ],
+   "source": [
+    "# We run the model on a single note and see that, at least for this example, the predicted codes match the true codes exactly.\n",
+    "pred = my_mimic_pipeline.predict(df['TEXT'].iloc[10])\n",
+    "true = df['TARGET'].iloc[10]\n",
+    "print(\"Predicted ICD9 codes: {}\".format(pred))\n",
+    "print(\"True ICD9 codes: {}\".format(true))"
+   ]
+  }
+ ]
+}
\ No newline at end of file