{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3610jvsc74a57bd01452ce364e145c8938a76e90050576b8a2a4d70ee75de50f3361ff243fa2a5f7",
   "display_name": "Python 3.6.10 64-bit ('EHRKit': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "# Mimic coding \n",
    "This notebook walks the user through the training and testing of a icd9 coding classifier, specifically built on and for the MIMIC-III dataset. The notebook utilizes the `mimic_icd9_coding` module, and the example uses a neural network classifier, although any other classifier may be used.\n",
    "\n",
    "The user must set the root directory for the EHRKit, and optionally the data directories."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "ROOT_EHR_DIR = '<EHRKit Path>' # set your root EHRKit directory here (with the '/' at the end)\n",
    "import sys\n",
    "import os\n",
    "sys.path.append(os.path.dirname(ROOT_EHR_DIR))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set your mimic path here\n",
    "# Put all of the individual mimic csv files in MIMIC_PATH, with the `/` at the end. These files should be all cap csv files, such as NOTEEVENTS.csv. Keep OUTPUT_DATA_PATH empty, the processed data will be deposited there.\n",
    "OUTPUT_DATA_PATH = ROOT_EHR_DIR + 'data/output_data/'\n",
    "MIMIC_PATH = ROOT_EHR_DIR + 'data/mimic_data/'"
   ]
  },
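  {
   "source": [
    "Optional sanity check, a minimal sketch using only the standard library: fail fast if `MIMIC_PATH` does not exist, and create `OUTPUT_DATA_PATH` if it is missing."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check sketch: verify the raw-data folder exists and create the output folder if needed\n",
    "assert os.path.isdir(MIMIC_PATH), 'MIMIC_PATH should point at the folder of raw MIMIC-III csv files'\n",
    "os.makedirs(OUTPUT_DATA_PATH, exist_ok=True)"
   ]
  },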
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
      "The text.latex.preview rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
      "The mathtext.fallback_to_cm rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: Support for setting the 'mathtext.fallback_to_cm' rcParam is deprecated since 3.3 and will be removed two minor releases later; use 'mathtext.fallback : 'cm' instead.\n",
      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
      "The validate_bool_maybe_none function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
      "The savefig.jpeg_quality rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
      "The keymap.all_axes rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
      "The animation.avconv_path rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
      "In /home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
      "The animation.avconv_args rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n"
     ]
    }
   ],
   "source": [
    "from mimic_icd9_coding.coding_pipeline import codingPipeline\n",
    "from mimic_icd9_coding.utils.mimic_data_preparation import run_mimic_prep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the mimic label preparation, examine the top n most common labels, and create a dataset of labels and text\n",
    "run_mimic_prep(output_folder = OUTPUT_DATA_PATH, mimic_data_path= MIMIC_PATH)\n"
   ]
  },
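  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional check, a sketch: list whatever run_mimic_prep deposited in OUTPUT_DATA_PATH\n",
    "# (the exact file names depend on the mimic_icd9_coding version, so we just list the folder)\n",
    "print(sorted(os.listdir(OUTPUT_DATA_PATH)))"
   ]
  },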
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Building basic tfidf pipeline\n",
      "Iteration 1, loss = 39.26339982\n",
      "Iteration 2, loss = 27.15691598\n",
      "Iteration 3, loss = 25.43022223\n",
      "Iteration 4, loss = 24.39200925\n",
      "Iteration 5, loss = 23.53421875\n",
      "Iteration 6, loss = 22.71761089\n",
      "Iteration 7, loss = 21.98377343\n",
      "Iteration 8, loss = 21.40524055\n",
      "Iteration 9, loss = 20.91691075\n",
      "Iteration 10, loss = 20.46976534\n",
      "Iteration 11, loss = 20.05602722\n",
      "Iteration 12, loss = 19.67969932\n",
      "Iteration 13, loss = 19.32472493\n",
      "Iteration 14, loss = 18.98996470\n",
      "Iteration 15, loss = 18.69047703\n",
      "Iteration 16, loss = 18.41708954\n",
      "Iteration 17, loss = 18.16937539\n",
      "Iteration 18, loss = 17.94176193\n",
      "Iteration 19, loss = 17.72638875\n",
      "Iteration 20, loss = 17.52546337\n",
      "Iteration 21, loss = 17.33147929\n",
      "Iteration 22, loss = 17.14397179\n",
      "Iteration 23, loss = 16.97036583\n",
      "Iteration 24, loss = 16.80744903\n",
      "Iteration 25, loss = 16.65087996\n",
      "Iteration 26, loss = 16.50350788\n",
      "Iteration 27, loss = 16.36437565\n",
      "Iteration 28, loss = 16.22898538\n",
      "Iteration 29, loss = 16.10017515\n",
      "Iteration 30, loss = 15.97871846\n",
      "Iteration 31, loss = 15.86154579\n",
      "Iteration 32, loss = 15.75070126\n",
      "Iteration 33, loss = 15.64012650\n",
      "Iteration 34, loss = 15.53565693\n",
      "Iteration 35, loss = 15.43139960\n",
      "Iteration 36, loss = 15.32890083\n",
      "Iteration 37, loss = 15.23351048\n",
      "Iteration 38, loss = 15.14100144\n",
      "Iteration 39, loss = 15.05027343\n",
      "Iteration 40, loss = 14.96409722\n",
      "Iteration 41, loss = 14.88040182\n",
      "Iteration 42, loss = 14.80062628\n",
      "Iteration 43, loss = 14.72397634\n",
      "Iteration 44, loss = 14.64785816\n",
      "Iteration 45, loss = 14.57387839\n",
      "Iteration 46, loss = 14.50342585\n",
      "Iteration 47, loss = 14.43510663\n",
      "Iteration 48, loss = 14.36782499\n",
      "Iteration 49, loss = 14.30226749\n",
      "Iteration 50, loss = 14.24259876\n",
      "Iteration 51, loss = 14.17783637\n",
      "Iteration 52, loss = 14.12287286\n",
      "Iteration 53, loss = 14.05997085\n",
      "Iteration 54, loss = 14.00637036\n",
      "Iteration 55, loss = 13.94991378\n",
      "Iteration 56, loss = 13.89237458\n",
      "Iteration 57, loss = 13.83448240\n",
      "Iteration 58, loss = 13.77879710\n",
      "Iteration 59, loss = 13.72299570\n",
      "Iteration 60, loss = 13.66682543\n",
      "Iteration 61, loss = 13.61414798\n",
      "Iteration 62, loss = 13.56323821\n",
      "Iteration 63, loss = 13.51129347\n",
      "Iteration 64, loss = 13.46183629\n",
      "Iteration 65, loss = 13.41255142\n",
      "Iteration 66, loss = 13.36630654\n",
      "Iteration 67, loss = 13.31819055\n",
      "Iteration 68, loss = 13.27420876\n",
      "Iteration 69, loss = 13.23153613\n",
      "Iteration 70, loss = 13.18662444\n",
      "Iteration 71, loss = 13.14404019\n",
      "Iteration 72, loss = 13.10327418\n",
      "Iteration 73, loss = 13.06489162\n",
      "Iteration 74, loss = 13.02366019\n",
      "Iteration 75, loss = 12.98480324\n",
      "Iteration 76, loss = 12.94750608\n",
      "Iteration 77, loss = 12.90973475\n",
      "Iteration 78, loss = 12.86992567\n",
      "Iteration 79, loss = 12.83378217\n",
      "Iteration 80, loss = 12.80129023\n",
      "Iteration 81, loss = 12.76208668\n",
      "Iteration 82, loss = 12.72865387\n",
      "Iteration 83, loss = 12.69202347\n",
      "Iteration 84, loss = 12.65801739\n",
      "Iteration 85, loss = 12.62539280\n",
      "Iteration 86, loss = 12.59177266\n",
      "Iteration 87, loss = 12.56355448\n",
      "Iteration 88, loss = 12.52974378\n",
      "Iteration 89, loss = 12.49803458\n",
      "Iteration 90, loss = 12.47011190\n",
      "Iteration 91, loss = 12.44098206\n",
      "Iteration 92, loss = 12.41264979\n",
      "Iteration 93, loss = 12.38165628\n",
      "Iteration 94, loss = 12.35401116\n",
      "Iteration 95, loss = 12.32752414\n",
      "Iteration 96, loss = 12.29908133\n",
      "Iteration 97, loss = 12.27293193\n",
      "Iteration 98, loss = 12.24794822\n",
      "Iteration 99, loss = 12.22328992\n",
      "Iteration 100, loss = 12.19652259\n",
      "/home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/sklearn/neural_network/_multilayer_perceptron.py:585: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (100) reached and the optimization hasn't converged yet.\n",
      "  % self.max_iter, ConvergenceWarning)\n",
      "/home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "\n",
      "Classification Report\n",
      "======================================================\n",
      "\n",
      "               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.66      0.45      0.54       485\n",
      "           1       0.40      0.19      0.26      1046\n",
      "           2       0.69      0.60      0.64      1826\n",
      "           3       0.49      0.25      0.33      2191\n",
      "           4       0.83      0.65      0.73       608\n",
      "           5       0.54      0.32      0.40      1470\n",
      "           6       0.55      0.37      0.44      1327\n",
      "           7       0.56      0.25      0.34       337\n",
      "           8       0.75      0.55      0.63       748\n",
      "           9       0.81      0.75      0.78       385\n",
      "          10       0.43      0.16      0.23       486\n",
      "          11       0.67      0.56      0.61       337\n",
      "          12       0.66      0.49      0.56       538\n",
      "          13       0.73      0.59      0.65       542\n",
      "          14       0.80      0.70      0.75      1571\n",
      "          15       0.88      0.77      0.82      4184\n",
      "          16       0.24      0.05      0.08       397\n",
      "          17       0.73      0.64      0.68      4229\n",
      "          18       0.70      0.51      0.59       679\n",
      "          19       0.27      0.02      0.04       354\n",
      "          20       0.60      0.53      0.56      4370\n",
      "          21       0.65      0.37      0.47       786\n",
      "          22       0.40      0.14      0.21       654\n",
      "          23       0.50      0.29      0.37       269\n",
      "          24       0.58      0.42      0.49      3737\n",
      "          25       0.50      0.21      0.29       591\n",
      "          26       0.59      0.31      0.41      1175\n",
      "          27       0.48      0.11      0.18       485\n",
      "          28       0.70      0.50      0.58       308\n",
      "          29       0.40      0.06      0.11       458\n",
      "          30       0.65      0.44      0.52       560\n",
      "          31       0.73      0.44      0.55       416\n",
      "          32       0.51      0.30      0.38       846\n",
      "          33       0.62      0.39      0.48       603\n",
      "          34       0.50      0.35      0.41      1432\n",
      "          35       0.71      0.48      0.57       675\n",
      "          36       0.50      0.35      0.41       386\n",
      "          37       0.42      0.13      0.20       342\n",
      "          38       0.64      0.41      0.50       460\n",
      "          39       0.57      0.34      0.43       955\n",
      "          40       0.52      0.31      0.39       537\n",
      "          41       0.59      0.29      0.39       399\n",
      "          42       0.71      0.70      0.71      6165\n",
      "          43       0.69      0.54      0.61      1873\n",
      "          44       0.79      0.69      0.74      1606\n",
      "          45       0.43      0.22      0.29       547\n",
      "          46       0.34      0.06      0.10       333\n",
      "          47       0.83      0.74      0.78      4202\n",
      "          48       0.72      0.55      0.62       319\n",
      "          49       0.47      0.23      0.31       769\n",
      "          50       0.61      0.37      0.46      1620\n",
      "          51       0.62      0.37      0.46       631\n",
      "          52       0.64      0.30      0.41       431\n",
      "          53       0.85      0.76      0.80      5036\n",
      "          54       0.76      0.70      0.73      3901\n",
      "          55       0.66      0.37      0.48       332\n",
      "          56       0.72      0.51      0.59       460\n",
      "          57       0.40      0.22      0.28       400\n",
      "          58       0.57      0.28      0.38       515\n",
      "          59       0.79      0.60      0.68       479\n",
      "          60       0.35      0.09      0.14       516\n",
      "          61       0.54      0.39      0.45       496\n",
      "          62       0.69      0.62      0.65       304\n",
      "          63       0.36      0.19      0.25      1459\n",
      "          64       0.44      0.27      0.33       616\n",
      "          65       0.66      0.41      0.51       399\n",
      "          66       0.75      0.56      0.64      1044\n",
      "          67       0.58      0.52      0.55      1046\n",
      "          68       0.56      0.31      0.40       976\n",
      "          69       0.72      0.60      0.65      3891\n",
      "          70       0.64      0.44      0.52       333\n",
      "          71       0.71      0.53      0.61      2201\n",
      "          72       0.54      0.31      0.40       304\n",
      "          73       0.51      0.22      0.31       334\n",
      "          74       0.58      0.39      0.47       520\n",
      "          75       0.67      0.34      0.45       411\n",
      "          76       0.31      0.04      0.07       388\n",
      "          77       0.51      0.34      0.41       296\n",
      "          78       0.48      0.16      0.24       314\n",
      "          79       0.83      0.63      0.72       886\n",
      "          80       0.68      0.61      0.65       495\n",
      "          81       0.66      0.46      0.54       326\n",
      "          82       0.70      0.53      0.60       329\n",
      "          83       0.80      0.50      0.61       473\n",
      "          84       0.47      0.22      0.30       728\n",
      "          85       0.75      0.66      0.70      3314\n",
      "          86       0.75      0.60      0.66      1918\n",
      "          87       0.21      0.01      0.03       296\n",
      "          88       0.68      0.51      0.58      2144\n",
      "          89       0.66      0.46      0.55       600\n",
      "          90       0.58      0.37      0.45       481\n",
      "          91       0.62      0.42      0.50       887\n",
      "          92       0.44      0.14      0.21       444\n",
      "          93       0.32      0.05      0.09       426\n",
      "          94       0.33      0.01      0.02       323\n",
      "          95       0.60      0.41      0.49       814\n",
      "          96       0.96      0.95      0.96       898\n",
      "          97       0.79      0.79      0.79       685\n",
      "          98       0.91      0.88      0.89       759\n",
      "          99       0.62      0.56      0.59       440\n",
      "         100       0.49      0.20      0.29      1563\n",
      "         101       0.42      0.11      0.17       428\n",
      "         102       0.61      0.44      0.51      1498\n",
      "         103       0.59      0.13      0.22       412\n",
      "         104       0.35      0.08      0.14       733\n",
      "         105       0.52      0.27      0.35       566\n",
      "         106       0.59      0.35      0.44       607\n",
      "         107       0.35      0.08      0.12      1137\n",
      "         108       0.36      0.05      0.09       505\n",
      "         109       0.78      0.56      0.65       175\n",
      "         110       0.67      0.53      0.59       279\n",
      "         111       0.78      0.61      0.68       299\n",
      "         112       0.52      0.47      0.49       323\n",
      "         113       0.54      0.27      0.36       256\n",
      "         114       0.66      0.58      0.62      1654\n",
      "         115       0.53      0.29      0.38      1247\n",
      "         116       0.43      0.25      0.32      1473\n",
      "         117       0.52      0.33      0.41      1450\n",
      "         118       0.82      0.69      0.75       393\n",
      "         119       0.11      0.01      0.01       596\n",
      "         120       0.46      0.25      0.33      1876\n",
      "         121       0.75      0.64      0.69       853\n",
      "         122       0.40      0.13      0.19      1010\n",
      "         123       0.43      0.07      0.12       306\n",
      "         124       0.77      0.78      0.78       632\n",
      "         125       0.58      0.38      0.46      1941\n",
      "         126       0.37      0.13      0.19      1115\n",
      "         127       0.34      0.08      0.13      1359\n",
      "         128       0.79      0.82      0.81       774\n",
      "         129       0.89      0.91      0.90       699\n",
      "         130       0.88      0.85      0.86       305\n",
      "         131       0.43      0.16      0.23       353\n",
      "         132       0.43      0.16      0.24       659\n",
      "         133       0.43      0.25      0.32       413\n",
      "         134       0.60      0.42      0.49      2650\n",
      "         135       0.40      0.15      0.22       640\n",
      "         136       0.51      0.41      0.46       224\n",
      "         137       0.55      0.36      0.44      2125\n",
      "         138       0.24      0.09      0.13       309\n",
      "\n",
      "   micro avg       0.67      0.47      0.55    141154\n",
      "   macro avg       0.58      0.39      0.45    141154\n",
      "weighted avg       0.63      0.47      0.53    141154\n",
      " samples avg       0.67      0.49      0.54    141154\n",
      "\n",
      "Pipeline complete\n",
      "/home/lily/br384/anaconda3/envs/EHRKit/lib/python3.6/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "# Create and train mimic data preprocessing and model, default is a random forest model\n",
    "# But any model can be passed in (SKlearn or other model which fits and predicts in the same manner)\n",
    "# Note that the resulting classification report identifies the metrics for each of the 138 different icd9 codes that this model investigates.\n",
    "print(\"Building basic tfidf pipeline\")\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "clf = MLPClassifier(hidden_layer_sizes=(100,), max_iter=100, verbose=True)\n",
    "# Switch max_iter to 100 for better results, but to run for the first time 10 is good\n",
    "my_mimic_pipeline = codingPipeline(verbose=True, model=clf, data_path = OUTPUT_DATA_PATH)\n",
    "print(\"Pipeline complete\")\n",
    "\n"
   ]
  },
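  {
   "source": [
    "As the comments above note, any estimator with a scikit-learn-style `fit`/`predict` interface can be swapped in. The cell below is a sketch of that, passing a random forest (the pipeline's stated default) the same way; `rf_pipeline` is a hypothetical name, and running the cell retrains from scratch."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: swap in a different scikit-learn estimator\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "rf_clf = RandomForestClassifier(n_estimators=100, n_jobs=-1)\n",
    "rf_pipeline = codingPipeline(verbose=True, model=rf_clf, data_path=OUTPUT_DATA_PATH)"
   ]
  },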
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Auroc is 0.69\n"
     ]
    }
   ],
   "source": [
    "# Let's check out the auroc\n",
    "auroc = my_mimic_pipeline.auroc\n",
    "print(\"Auroc is {:.2f}\".format(auroc))"
   ]
  },
  {
   "source": [
    "## Let's test out the model on a specific note!"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here we load the data into the pipeline, this function simply saves the data, we don't want to save the data automatically because it uses more memory\n",
    "my_mimic_pipeline.load_data()\n",
    "df = my_mimic_pipeline.data"
   ]
  },
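  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at the prepared dataset (assuming it is a pandas DataFrame, as the .iloc\n",
    "# indexing below suggests): each row pairs a note ('TEXT') with its ICD-9 codes ('TARGET')\n",
    "print(df.shape)\n",
    "df[['TEXT', 'TARGET']].head()"
   ]
  },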
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Predicted ICD9 codes: [('424', '441', '785')]\nTrue ICD9 codes: ['424', '441', '785']\n"
     ]
    }
   ],
   "source": [
    "# We run the algorithm and see that at least for this example our model is pretty good\n",
    "pred = my_mimic_pipeline.predict(df['TEXT'].iloc[10])\n",
    "true = df['TARGET'].iloc[10]\n",
    "print(\"Predicted ICD9 codes: {}\".format(pred))\n",
    "print(\"True ICD9 codes: {}\".format(true))"
   ]
  }
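  ,
  {
   "source": [
    "As a final sketch, we can quantify the match for this note: the output above shows that `predict` returns a list containing one tuple of codes, so we compare it to the true codes as sets."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: compare predicted and true codes as sets (pred is a list holding one tuple of codes)\n",
    "pred_set, true_set = set(pred[0]), set(true)\n",
    "print(\"Matched {} of {} true codes\".format(len(pred_set & true_set), len(true_set)))"
   ]
  }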
 ]
}