--- a +++ b/Clinical Deterioration Prediction Model - Selection of Ensemble Algorithms .ipynb @@ -0,0 +1,1217 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "hide": true + }, + "source": [ + "import numpy as np\n", + "import os\n", + "import pandas as pd\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "# ^^^ pyforest auto-imports - don't write above this line\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import os\n", + "import sklearn\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import pandas as pd\n", + "import os\n", + "\n", + "$$\n", + "\\renewcommand{\\like}{{\\cal L}}\n", + "\\renewcommand{\\loglike}{{\\ell}}\n", + "\\renewcommand{\\err}{{\\cal E}}\n", + "\\renewcommand{\\dat}{{\\cal D}}\n", + "\\renewcommand{\\hyp}{{\\cal H}}\n", + "\\renewcommand{\\Ex}[2]{E_{#1}[#2]}\n", + "\\renewcommand{\\x}{{\\mathbf x}}\n", + "\\renewcommand{\\v}[1]{{\\mathbf #1}}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Clinical Deterioration Prediction Model - Selection of Ensemble Algorithms " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data\n", + "\n", + "The final dataset used for the inferential statistics project includes unique ICU admission of 46,234 patients’ demographic (age), vital (blood pressure, heart rate, body temperature, and Glasgow Comma Scale), underlying conditions (HIV, metastatic cancer, and hematologic malignancy), admission type (scheduled surgical, medical, or unscheduled surgical), renal (urinary output, and Blood Urea Nitrogen), and others (serum bicarbonate level, sodium level, potassium level, and bilirubin level) data. This dataset is build based on the commonly used mortality prediction tool, Simplified Acute Physiology Score II (SAPSII). " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import os'); }\n", + " " + ], + "text/plain": [ + "<IPython.core.display.Javascript object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'C:\\\\Users\\\\abebu\\\\Dropbox\\\\Data Science\\\\Projects\\\\Capstone Project 1\\\\Potential Projects\\\\9. MIMIC\\\\Machine Learning\\\\Clinical-Deterioration-Prediction-Model---Bayesian-Linear-Regression'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "os.getcwd()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import os'); }\n", + " " + ], + "text/plain": [ + "<IPython.core.display.Javascript object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "\n", + " if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import os\\nimport pandas as pd'); }\n", + " " + ], + "text/plain": [ + "<IPython.core.display.Javascript object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>SUBJECT_ID</th>\n", + " <th>HADM_ID</th>\n", + " <th>ICUSTAY_ID</th>\n", + " <th>los</th>\n", + " <th>hdeath</th>\n", + " <th>death</th>\n", + " <th>admission</th>\n", + " <th>ud</th>\n", + " <th>bun</th>\n", + " <th>Bicarbonate</th>\n", + " <th>...</th>\n", + " <th>WBC_3.0</th>\n", + " <th>hr_0.0</th>\n", + " <th>hr_2.0</th>\n", + " <th>hr_4.0</th>\n", + " <th>hr_7.0</th>\n", + " <th>hr_11.0</th>\n", + " <th>bp_0.0</th>\n", + " <th>bp_2.0</th>\n", + " <th>bp_5.0</th>\n", + " <th>bp_13.0</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <td>0</td>\n", + " <td>268</td>\n", + " <td>110404</td>\n", + " <td>280836</td>\n", + " <td>3.2490</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>6.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>269</td>\n", + " <td>106296</td>\n", + " <td>206613</td>\n", + " <td>3.2788</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>8</td>\n", + " <td>17.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>270</td>\n", + " <td>188028</td>\n", + " <td>220345</td>\n", + " <td>2.8939</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>271</td>\n", + " <td>173727</td>\n", + " <td>249196</td>\n", + " <td>2.0600</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>272</td>\n", + " <td>164716</td>\n", + " <td>210407</td>\n", + " <td>1.6202</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 33 columns</p>\n", + "</div>" + ], + "text/plain": [ + " SUBJECT_ID HADM_ID ICUSTAY_ID los hdeath death admission ud \\\n", + "0 268 110404 280836 3.2490 1 1 8 0.0 \n", + "1 269 106296 206613 3.2788 0 0 8 17.0 \n", + "2 270 188028 220345 2.8939 0 0 0 0.0 \n", + "3 271 173727 249196 2.0600 0 0 8 0.0 \n", + "4 272 164716 210407 1.6202 0 0 8 0.0 \n", + "\n", + " bun Bicarbonate ... WBC_3.0 hr_0.0 hr_2.0 hr_4.0 hr_7.0 hr_11.0 \\\n", + "0 6.0 0.0 ... 0 0 0 0 0 1 \n", + "1 0.0 0.0 ... 0 1 0 0 0 0 \n", + "2 0.0 0.0 ... 0 0 0 0 0 1 \n", + "3 0.0 0.0 ... 0 1 0 0 0 0 \n", + "4 0.0 0.0 ... 0 1 0 0 0 0 \n", + "\n", + " bp_0.0 bp_2.0 bp_5.0 bp_13.0 \n", + "0 0 0 0 1 \n", + "1 0 0 1 0 \n", + "2 0 0 0 1 \n", + "3 1 0 0 0 \n", + "4 0 0 1 0 \n", + "\n", + "[5 rows x 33 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "os.chdir(\"C://Users/abebu/Google Drive/mimic-iii-clinical-database-1.4\")\n", + "saps = pd.read_csv(\"saps_ts.csv\", header=0, index_col=0)\n", + "saps.head()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>hdeath</th>\n", + " <th>admission</th>\n", + " <th>ud</th>\n", + " <th>bun</th>\n", + " <th>Bicarbonate</th>\n", + " <th>ventilation</th>\n", + " <th>Temp</th>\n", + " <th>Bilirubin</th>\n", + " <th>gcs</th>\n", + " <th>AGE</th>\n", + " <th>...</th>\n", + " <th>WBC_3.0</th>\n", + " <th>hr_0.0</th>\n", + " <th>hr_2.0</th>\n", + " <th>hr_4.0</th>\n", + " <th>hr_7.0</th>\n", + " <th>hr_11.0</th>\n", + " <th>bp_0.0</th>\n", + " <th>bp_2.0</th>\n", + " <th>bp_5.0</th>\n", + " <th>bp_13.0</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>6.0</td>\n", + " <td>0.0</td>\n", + " <td>6.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>26.0</td>\n", + " <td>12.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>8</td>\n", + " <td>17.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>7.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>18.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>0</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>6.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>7.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>0</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>12.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " </tr>\n", + " <tr>\n", + " <td>61112</td>\n", + " <td>0</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>6.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>26.0</td>\n", + " <td>16.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>61113</td>\n", + " <td>0</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>5.0</td>\n", + " <td>18.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>61114</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>6.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>7.0</td>\n", + " <td>7.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>61115</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>6.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>26.0</td>\n", + " <td>12.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>61116</td>\n", + " <td>0</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>11.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>26.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>61117 rows × 28 columns</p>\n", + "</div>" + ], + "text/plain": [ + " hdeath admission ud bun Bicarbonate ventilation Temp \\\n", + "0 1 8 0.0 6.0 0.0 6.0 0.0 \n", + "1 0 8 17.0 0.0 0.0 0.0 0.0 \n", + "2 0 0 0.0 0.0 0.0 0.0 3.0 \n", + "3 0 8 0.0 0.0 0.0 6.0 3.0 \n", + "4 0 8 0.0 0.0 0.0 0.0 3.0 \n", + "... ... ... ... ... ... ... ... \n", + "61112 0 8 0.0 0.0 0.0 6.0 3.0 \n", + "61113 0 8 0.0 0.0 0.0 0.0 3.0 \n", + "61114 0 0 0.0 0.0 0.0 6.0 3.0 \n", + "61115 0 0 0.0 0.0 0.0 6.0 3.0 \n", + "61116 0 8 0.0 0.0 0.0 11.0 3.0 \n", + "\n", + " Bilirubin gcs AGE ... WBC_3.0 hr_0.0 hr_2.0 hr_4.0 hr_7.0 \\\n", + "0 0.0 26.0 12.0 ... 0 0 0 0 0 \n", + "1 0.0 0.0 7.0 ... 0 1 0 0 0 \n", + "2 0.0 0.0 18.0 ... 0 0 0 0 0 \n", + "3 0.0 0.0 7.0 ... 0 1 0 0 0 \n", + "4 0.0 0.0 12.0 ... 0 1 0 0 0 \n", + "... ... ... ... ... ... ... ... ... ... \n", + "61112 0.0 26.0 16.0 ... 0 0 1 0 0 \n", + "61113 0.0 5.0 18.0 ... 0 0 1 0 0 \n", + "61114 0.0 7.0 7.0 ... 0 1 0 0 0 \n", + "61115 0.0 26.0 12.0 ... 0 1 0 0 0 \n", + "61116 0.0 26.0 0.0 ... 0 1 0 0 0 \n", + "\n", + " hr_11.0 bp_0.0 bp_2.0 bp_5.0 bp_13.0 \n", + "0 1 0 0 0 1 \n", + "1 0 0 0 1 0 \n", + "2 1 0 0 0 1 \n", + "3 0 1 0 0 0 \n", + "4 0 0 0 1 0 \n", + "... ... ... ... ... ... \n", + "61112 0 0 0 0 1 \n", + "61113 0 1 0 0 0 \n", + "61114 0 0 0 1 0 \n", + "61115 0 0 0 0 1 \n", + "61116 0 0 0 1 0 \n", + "\n", + "[61117 rows x 28 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "saps_e=saps.drop(['los','death','SUBJECT_ID','HADM_ID', 'ICUSTAY_ID'], axis=1)\n", + "saps_e" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# Pandas and numpy for data manipulation\n", + "import pandas as pd\n", + "import numpy as np\n", + "np.random.seed(42)\n", + "\n", + " \n", + "# Matplotlib and seaborn for plotting\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "import matplotlib\n", + "matplotlib.rcParams['font.size'] = 16\n", + "matplotlib.rcParams['figure.figsize'] = (9, 9)\n", + "\n", + "import seaborn as sns\n", + "\n", + "# Scipy helper functions\n", + "from scipy.stats import percentileofscore\n", + "from scipy import stats" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Standard ensembel ML Models for comparison\n", + "\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.ensemble import ExtraTreesClassifier\n", + "from sklearn.ensemble import GradientBoostingClassifier\n", + "\n", + "# Splitting data into training/testing\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.preprocessing import MinMaxScaler\n", + "\n", + "# Metrics\n", + "from sklearn.metrics import mean_squared_error, mean_absolute_error, median_absolute_error\n", + "\n", + "# Distributions\n", + "import scipy" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# PyMC3 for Bayesian Inference\n", + "import pymc3 as pm" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# grade and returns training and testing datasets\n", + "def format_data(df):\n", + " # Target is hospital death\n", + " labels = saps_e['hdeath']\n", + " \n", + " # Drop target (hdeath) from features\n", + " df = df.drop(columns=['hdeath'])\n", + " \n", + " # Split into training/testing sets with 30% split\n", + " X_train, X_test, y_train, y_test = train_test_split(df, labels, \n", + " test_size = 0.30,\n", + " random_state=42)\n", + " \n", + " return X_train, X_test, y_train, y_test\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import os\\nimport pandas as pd\\nfrom sklearn.model_selection import train_test_split'); }\n", + " " + ], + "text/plain": [ + "<IPython.core.display.Javascript object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>admission</th>\n", + " <th>ud</th>\n", + " <th>bun</th>\n", + " <th>Bicarbonate</th>\n", + " <th>ventilation</th>\n", + " <th>Temp</th>\n", + " <th>Bilirubin</th>\n", + " <th>gcs</th>\n", + " <th>AGE</th>\n", + " <th>UO</th>\n", + " <th>...</th>\n", + " <th>WBC_3.0</th>\n", + " <th>hr_0.0</th>\n", + " <th>hr_2.0</th>\n", + " <th>hr_4.0</th>\n", + " <th>hr_7.0</th>\n", + " <th>hr_11.0</th>\n", + " <th>bp_0.0</th>\n", + " <th>bp_2.0</th>\n", + " <th>bp_5.0</th>\n", + " <th>bp_13.0</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <td>53545</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>15.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>51512</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>6.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>18.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>23837</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>12.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <td>21929</td>\n", + " <td>0</td>\n", + " <td>0.0</td>\n", + " <td>10.0</td>\n", + " <td>0.0</td>\n", + " <td>6.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>5.0</td>\n", + " <td>18.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <td>57339</td>\n", + " <td>8</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>3.0</td>\n", + " <td>0.0</td>\n", + " <td>0.0</td>\n", + " <td>12.0</td>\n", + " <td>0.0</td>\n", + " <td>...</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 27 columns</p>\n", + "</div>" + ], + "text/plain": [ + " admission ud bun Bicarbonate ventilation Temp Bilirubin gcs \\\n", + "53545 8 0.0 0.0 0.0 0.0 3.0 0.0 0.0 \n", + "51512 8 0.0 6.0 0.0 0.0 3.0 0.0 0.0 \n", + "23837 8 0.0 0.0 0.0 0.0 3.0 0.0 0.0 \n", + "21929 0 0.0 10.0 0.0 6.0 3.0 0.0 5.0 \n", + "57339 8 0.0 0.0 0.0 0.0 3.0 0.0 0.0 \n", + "\n", + " AGE UO ... WBC_3.0 hr_0.0 hr_2.0 hr_4.0 hr_7.0 hr_11.0 \\\n", + "53545 15.0 0.0 ... 0 1 0 0 0 0 \n", + "51512 18.0 0.0 ... 0 0 1 0 0 0 \n", + "23837 12.0 0.0 ... 0 0 1 0 0 0 \n", + "21929 18.0 0.0 ... 0 1 0 0 0 0 \n", + "57339 12.0 0.0 ... 0 0 1 0 0 0 \n", + "\n", + " bp_0.0 bp_2.0 bp_5.0 bp_13.0 \n", + "53545 1 0 0 0 \n", + "51512 0 0 1 0 \n", + "23837 0 0 1 0 \n", + "21929 0 0 0 1 \n", + "57339 0 0 1 0 \n", + "\n", + "[5 rows x 27 columns]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_train, X_test, y_train, y_test = format_data(saps_e)\n", + "X_train.head()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(42781, 27)\n", + "(18336, 27)\n" + ] + } + ], + "source": [ + "print(X_train.shape)\n", + "print(X_test.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Standard Ensemble Machine Learning Models" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# Evaluate several ml models by training on training set and testing on testing set\n", + "def evaluate(X_train, X_test, y_train, y_test):\n", + " #Names of models\n", + " model_name_list = ['Random Forest', 'Extra Trees',\n", + " 'Gradient Boosted']\n", + " model = np.arange(1, 4)\n", + " train_accuracy = np.empty(len(model))\n", + " test_accuracy = np.empty(len(model))\n", + " # Instantiate the models\n", + " model1 = RandomForestClassifier(n_estimators=100)\n", + " model2 = ExtraTreesClassifier(n_estimators=100)\n", + " model3 = GradientBoostingClassifier(n_estimators=100)\n", + " results =pd.DataFrame(columns=['train_accuracy', 'test_accuracy']) \n", + " \n", + " # Train and predict with each model\n", + " for i, model in enumerate([model1, model2, model3]):\n", + " model.fit(X_train, y_train)\n", + " \n", + " train_accuracy=model.score(X_train, y_train)\n", + " test_accuracy=model.score(X_test, y_test)\n", + " \n", + " model_name=model_name_list[i]\n", + " results.loc[model_name, :] =[train_accuracy, test_accuracy]\n", + " return results\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": [ + "\n", + " if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import numpy as np\\nimport os\\nimport pandas as pd\\nfrom sklearn.model_selection import train_test_split\\nfrom sklearn.ensemble import RandomForestClassifier'); }\n", + " " + ], + "text/plain": [ + "<IPython.core.display.Javascript object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "\n", + " if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import numpy as np\\nimport os\\nimport pandas as pd\\nfrom sklearn.model_selection import train_test_split\\nfrom sklearn.ensemble import RandomForestClassifier'); }\n", + " " + ], + "text/plain": [ + "<IPython.core.display.Javascript object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "\n", + " if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import numpy as np\\nimport os\\nimport pandas as pd\\nfrom sklearn.model_selection import train_test_split\\nfrom sklearn.ensemble import RandomForestClassifier'); }\n", + " " + ], + "text/plain": [ + "<IPython.core.display.Javascript object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "\n", + " if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import numpy as np\\nimport os\\nimport pandas as pd\\nfrom sklearn.model_selection import train_test_split\\nfrom sklearn.ensemble import RandomForestClassifier'); }\n", + " " + ], + "text/plain": [ + "<IPython.core.display.Javascript object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>train_accuracy</th>\n", + " <th>test_accuracy</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <td>Random Forest</td>\n", + " <td>0.949323</td>\n", + " <td>0.91983</td>\n", + " </tr>\n", + " <tr>\n", + " <td>Extra Trees</td>\n", + " <td>0.949323</td>\n", + " <td>0.919339</td>\n", + " </tr>\n", + " <tr>\n", + " <td>Gradient Boosted</td>\n", + " <td>0.924149</td>\n", + " <td>0.925993</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " train_accuracy test_accuracy\n", + "Random Forest 0.949323 0.91983\n", + "Extra Trees 0.949323 0.919339\n", + "Gradient Boosted 0.924149 0.925993" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results = evaluate(X_train, X_test, y_train, y_test)\n", + "results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}