2337 lines (2336 with data), 172.9 kB
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"hide": true
},
"source": [
"import sklearn\n",
"import os\n",
"import seaborn as sns\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"# ^^^ pyforest auto-imports - don't write above this line\n",
"# Clinical Deterioration Prediction Model - Logistic Regression\n",
"$$\n",
"\\renewcommand{\\like}{{\\cal L}}\n",
"\\renewcommand{\\loglike}{{\\ell}}\n",
"\\renewcommand{\\err}{{\\cal E}}\n",
"\\renewcommand{\\dat}{{\\cal D}}\n",
"\\renewcommand{\\hyp}{{\\cal H}}\n",
"\\renewcommand{\\Ex}[2]{E_{#1}[#2]}\n",
"\\renewcommand{\\x}{{\\mathbf x}}\n",
"\\renewcommand{\\v}[1]{{\\mathbf #1}}\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data\n",
"\n",
"The final dataset, also used for the inferential statistics project, covers the ICU admissions of 46,234 unique patients (61,117 rows in the loaded table). It includes demographic (age), vital-sign (blood pressure, heart rate, body temperature, and Glasgow Coma Scale), underlying-condition (HIV, metastatic cancer, and hematologic malignancy), admission-type (scheduled surgical, medical, or unscheduled surgical), renal (urine output and blood urea nitrogen), and other laboratory (serum bicarbonate, sodium, potassium, and bilirubin) data. The dataset is built on the variables of a commonly used mortality prediction tool, the Simplified Acute Physiology Score II (SAPS II)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import os'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'C:\\\\Users\\\\abebu\\\\Dropbox\\\\Data Science\\\\Projects\\\\Capstone Project 1\\\\Potential Projects\\\\9. MIMIC\\\\Machine Learning\\\\Clinical-Deterioration-Prediction-Model--Logistic-Regression'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"os.getcwd()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import os'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Data directory: set the MIMIC_DATA_DIR environment variable instead of editing a\n",
"# hard-coded, machine-specific path. Falls back to the original author's location.\n",
"DATA_DIR = os.environ.get(\"MIMIC_DATA_DIR\", \"C://Users/abebu/Google Drive/mimic-iii-clinical-database-1.4\")\n",
"os.chdir(DATA_DIR)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import os\\nimport pandas as pd'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>SUBJECT_ID</th>\n",
" <th>HADM_ID</th>\n",
" <th>ICUSTAY_ID</th>\n",
" <th>los</th>\n",
" <th>hdeath</th>\n",
" <th>death</th>\n",
" <th>admission</th>\n",
" <th>ud</th>\n",
" <th>bun</th>\n",
" <th>Bicarbonate</th>\n",
" <th>...</th>\n",
" <th>Sodium</th>\n",
" <th>Temp</th>\n",
" <th>Bilirubin</th>\n",
" <th>WBC</th>\n",
" <th>hr</th>\n",
" <th>gcs</th>\n",
" <th>bp</th>\n",
" <th>AGE</th>\n",
" <th>UO</th>\n",
" <th>saps2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>268</td>\n",
" <td>110404</td>\n",
" <td>280836</td>\n",
" <td>3.2490</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>8</td>\n",
" <td>0.0</td>\n",
" <td>6.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>11.0</td>\n",
" <td>26.0</td>\n",
" <td>13.0</td>\n",
" <td>12.0</td>\n",
" <td>0.0</td>\n",
" <td>82.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>269</td>\n",
" <td>106296</td>\n",
" <td>206613</td>\n",
" <td>3.2788</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8</td>\n",
" <td>17.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>7.0</td>\n",
" <td>0.0</td>\n",
" <td>37.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>270</td>\n",
" <td>188028</td>\n",
" <td>220345</td>\n",
" <td>2.8939</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>11.0</td>\n",
" <td>0.0</td>\n",
" <td>13.0</td>\n",
" <td>18.0</td>\n",
" <td>0.0</td>\n",
" <td>45.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>271</td>\n",
" <td>173727</td>\n",
" <td>249196</td>\n",
" <td>2.0600</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>7.0</td>\n",
" <td>0.0</td>\n",
" <td>24.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>272</td>\n",
" <td>164716</td>\n",
" <td>210407</td>\n",
" <td>1.6202</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>12.0</td>\n",
" <td>0.0</td>\n",
" <td>28.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 22 columns</p>\n",
"</div>"
],
"text/plain": [
" SUBJECT_ID HADM_ID ICUSTAY_ID los hdeath death admission ud \\\n",
"0 268 110404 280836 3.2490 1 1 8 0.0 \n",
"1 269 106296 206613 3.2788 0 0 8 17.0 \n",
"2 270 188028 220345 2.8939 0 0 0 0.0 \n",
"3 271 173727 249196 2.0600 0 0 8 0.0 \n",
"4 272 164716 210407 1.6202 0 0 8 0.0 \n",
"\n",
" bun Bicarbonate ... Sodium Temp Bilirubin WBC hr gcs bp \\\n",
"0 6.0 0.0 ... 0.0 0.0 0.0 0.0 11.0 26.0 13.0 \n",
"1 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 5.0 \n",
"2 0.0 0.0 ... 0.0 3.0 0.0 0.0 11.0 0.0 13.0 \n",
"3 0.0 0.0 ... 0.0 3.0 0.0 0.0 0.0 0.0 0.0 \n",
"4 0.0 0.0 ... 0.0 3.0 0.0 0.0 0.0 0.0 5.0 \n",
"\n",
" AGE UO saps2 \n",
"0 12.0 0.0 82.0 \n",
"1 7.0 0.0 37.0 \n",
"2 18.0 0.0 45.0 \n",
"3 7.0 0.0 24.0 \n",
"4 12.0 0.0 28.0 \n",
"\n",
"[5 rows x 22 columns]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the SAPS II feature table (the first CSV column is the row index) and preview it.\n",
"saps = pd.read_csv(\"saps_ts.csv\", index_col=0, header=0)\n",
"saps.head()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import os\\nimport pandas as pd'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# One-hot encode the categorical SAPS II sub-scores. Only columns that are still present are\n",
"# encoded, so re-running this cell (or loading an already-encoded CSV) does not raise KeyError.\n",
"dummy_cols = [c for c in ['Potassium', 'Sodium', 'WBC', 'hr', 'bp'] if c in saps.columns]\n",
"if dummy_cols:\n",
"    saps = pd.get_dummies(saps, columns=dummy_cols)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Save the encoded table under a new name so the raw input 'saps_ts.csv' is not overwritten;\n",
"# overwriting it makes the get_dummies cell fail on the next Restart & Run All.\n",
"saps.to_csv('saps_ts_encoded.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, run logistic regression using `saps2` (the sum of all SAPS II component scores) as the explanatory variable and in-hospital death (`hdeath`) as the target variable. "
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dimensions of y before reshaping: (61117,)\n",
"Dimensions of X before reshaping: (61117,)\n",
"Dimensions of y after reshaping: (61117,)\n",
"Dimensions of X after reshaping: (61117, 1)\n"
]
}
],
"source": [
"# Target: the hdeath flag. Single feature: the total SAPS II score (saps2).\n",
"X = saps['saps2'].values\n",
"y = saps['hdeath'].values\n",
"\n",
"print(\"Dimensions of y before reshaping: {}\".format(y.shape))\n",
"print(\"Dimensions of X before reshaping: {}\".format(X.shape))\n",
"\n",
"# scikit-learn expects a 2-D feature matrix, so X becomes a single column; y stays 1-D.\n",
"X = X[:, None]\n",
"\n",
"print(\"Dimensions of y after reshaping: {}\".format(y.shape))\n",
"print(\"Dimensions of X after reshaping: {}\".format(X.shape))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"* Split the data into a training and test (hold-out) set\n",
"* Train on the training set, and test for accuracy on the testing set"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Accuracy: 0.9077164735912037\n",
"Testing Accuracy: 0.9036649214659686\n"
]
}
],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hold out a test split (train_test_split's default size); the seed makes the split reproducible.\n",
"Xlr, Xtestlr, ylr, ytestlr = train_test_split(X, y, random_state=5)\n",
"\n",
"# Baseline logistic regression on the single saps2 feature.\n",
"clf = LogisticRegression(solver='lbfgs').fit(Xlr, ylr)\n",
"\n",
"train_acc = accuracy_score(ylr, clf.predict(Xlr))\n",
"test_acc = accuracy_score(ytestlr, clf.predict(Xtestlr))\n",
"print('Training Accuracy: {}'.format(train_acc))\n",
"print('Testing Accuracy: {}'.format(test_acc))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Hyperparameter Tuning "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model has some hyperparameters we can tune, hopefully for better performance. In logistic regression, the most important parameter to tune is the *regularization parameter* `C`, the inverse of the regularization strength (smaller `C` means stronger regularization). Regularization is not an intrinsic part of the logistic regression model, although scikit-learn's `LogisticRegression` applies L2 regularization with `C=1.0` by default. Regularization penalizes implausibly large regression coefficients and, with an L1 penalty, can act as a form of feature selection when data are sparse. We may not need it for our model, but it is worth checking. \n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import KFold\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"def cv_score(clf, x, y, score_func=accuracy_score, nfold=5):\n",
"    \"\"\"Return the mean of `score_func` over `nfold` K-fold splits of (x, y).\n",
"\n",
"    `clf` is refit on every training fold, so it is left fitted on the last fold.\n",
"    `score_func` is called as score_func(y_true, y_pred), the scikit-learn metric\n",
"    convention, so asymmetric metrics such as precision or recall are computed correctly.\n",
"    KFold is used without shuffling, so the folds are contiguous blocks of rows.\n",
"    \"\"\"\n",
"    result = 0\n",
"    for train, test in KFold(nfold).split(x):\n",
"        clf.fit(x[train], y[train])\n",
"        result += score_func(y[test], clf.predict(x[test]))  # (y_true, y_pred)\n",
"    return result / nfold"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9079345258458951\n"
]
}
],
"source": [
"clf = LogisticRegression(solver='lbfgs')\n",
"# Mean accuracy over the 5 cross-validation folds of the training split.\n",
"score = cv_score(clf, Xlr, ylr)\n",
"print(score)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the `cv_score` function (5-fold cross-validation on the training split) for a basic logistic regression model with scikit-learn's default L2 regularization (`C=1.0`), the mean accuracy on the held-out validation folds is `0.908` (about 91%). "
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"score: 0.9079345258458951, C:0.01\n",
"score: 0.9079345258458951, C:0.1\n",
"score: 0.9079345258458951, C:1\n",
"score: 0.9079345258458951, C:10\n",
"score: 0.9079345258458951, C:100\n",
"\n",
"The Maximum score with training data is 0.9079345258458951 for a C value of 0.01.\n"
]
}
],
"source": [
"# Inverse regularization strengths to compare with 5-fold CV on the training split.\n",
"Cs = [0.01, 0.1, 1, 10, 100]\n",
"cv_scores = {}\n",
"for c in Cs:\n",
"    clf = LogisticRegression(solver='lbfgs', C=c)\n",
"    score = cv_score(clf, Xlr, ylr)\n",
"    cv_scores[c] = score\n",
"    print(f'score: {score}, C:{c}')\n",
"# max() keeps the first C on ties, like the strict '>' comparison it replaces, and unlike\n",
"# starting from max_score = 0 it always defines max_C, even if every score is 0.\n",
"max_C = max(cv_scores, key=cv_scores.get)\n",
"max_score = cv_scores[max_C]\n",
"print(f'\\nThe Maximum score with training data is {max_score} for a C value of {max_C}.')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On the training set, every value of `C` gives the same mean cross-validation accuracy (0.9079345258458951), so the first value tried, `C=0.01`, is reported as the best."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The accuracy with the test data is 0.9036649214659686.\n"
]
}
],
"source": [
"# Refit on the full training split with the C chosen by cross-validation,\n",
"# then report accuracy on the held-out test split.\n",
"clf = LogisticRegression(solver='lbfgs', C=max_C).fit(Xlr, ylr)\n",
"print(f'The accuracy with the test data is {accuracy_score(clf.predict(Xtestlr), ytestlr)} for a C value of {max_C}.')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Running the model with C=0.01 gives the same accuracy on the test data as the default (C=1.0). This is not always the case, so it is important to experiment with the hyperparameters to find what works best with new data. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Grid Search"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best score on training data: 0.9080219037022492 using {'C': 0.1, 'penalty': 'l2', 'solver': 'liblinear'}\n"
]
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"model = LogisticRegression(max_iter=1000)\n",
"\n",
"# Not every solver supports every penalty, so the grid is a list of sub-grids that only\n",
"# contain valid (solver, penalty) pairs. This avoids fits that can only fail and would be\n",
"# silently scored 0, and it lets elasticnet run by supplying the l1_ratio it requires.\n",
"c_values = [100, 10, 1.0, 0.1, 0.01]\n",
"grid = [\n",
"    dict(solver=['newton-cg', 'lbfgs', 'sag'], penalty=['none', 'l2'], C=c_values),\n",
"    dict(solver=['liblinear'], penalty=['l1', 'l2'], C=c_values),\n",
"    dict(solver=['saga'], penalty=['none', 'l1', 'l2'], C=c_values),\n",
"    dict(solver=['saga'], penalty=['elasticnet'], C=c_values, l1_ratio=[0.5]),\n",
"]\n",
"# error_score=0 is kept so a penalty spelling the installed scikit-learn does not accept\n",
"# (e.g. the string 'none' in newer releases) scores 0 instead of aborting the search.\n",
"grid_search = GridSearchCV(estimator=model, param_grid=grid, n_jobs=-1, cv=5, scoring='accuracy', error_score=0)\n",
"grid_result = grid_search.fit(Xlr, ylr)\n",
"\n",
"# summarize results\n",
"print(f\"Best score on training data: {grid_result.best_score_} using {grid_result.best_params_}\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Score on test data: 0.9044502617801047\n"
]
}
],
"source": [
"# Accuracy of the refit best estimator on the held-out test split (scoring='accuracy').\n",
"print(f'Score on test data: {grid_result.score(Xtestlr, ytestlr)}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Grid search selects a different value of C this time, 0.1 (with the `liblinear` solver and an L2 penalty). The tuned model performs only slightly better on the test data (0.9044 vs 0.9036), which is almost the same. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's first set up some classification code that we will need for further discussion of the math. We first define a function `cv_optimize`, which takes a classifier `clf`, a grid of hyperparameters (such as a complexity parameter or regularization parameter) given as a dictionary `parameters`, a training set (as a samples x features array) `Xtrain`, and a set of labels `ytrain`. The code splits the training set into `n_folds` parts and carries out cross-validation, using each part in turn as the validation section and the remaining parts for training. It prints the best value of the parameters and returns the best classifier."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def cv_optimize(clf, parameters, Xtrain, ytrain, n_folds=5):\n",
"    \"\"\"Grid-search `parameters` for `clf` with `n_folds`-fold CV on (Xtrain, ytrain).\n",
"\n",
"    Prints the winning parameter combination and returns GridSearchCV's best_estimator_.\n",
"    \"\"\"\n",
"    search = sklearn.model_selection.GridSearchCV(clf, param_grid=parameters, cv=n_folds).fit(Xtrain, ytrain)\n",
"    print(\"BEST PARAMS\", search.best_params_)\n",
"    return search.best_estimator_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We then use this best classifier to fit the entire training set. This is done inside the `do_classify` function, which takes a dataframe `indf` as input. It uses the columns in the list `featurenames` as the features to train the classifier, and the column `targetname` as the target. Samples for which `targetname` has the value `target1val` are labeled 1, and all others 0. We split the dataframe into 80% training and 20% testing by default, standardizing the dataset if desired. (Standardizing a dataset means scaling the data so that each feature has zero mean and is expressed in units of its standard deviation.) We then train the model on the training set using cross-validation. Having obtained the best classifier from `cv_optimize`, we retrain it on the entire training set, then compute and print the training and testing accuracy. Finally, we return the trained classifier and the split data."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"hide": true
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"def do_classify(clf, parameters, indf, featurenames, targetname, target1val, standardize=False, train_size=0.8, random_state=None):\n",
"    \"\"\"Tune `clf` with cross-validation, fit it on a training split of `indf`, and report accuracy.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    clf : estimator passed to cv_optimize together with the `parameters` grid.\n",
"    indf : DataFrame holding the feature columns `featurenames` and the column `targetname`.\n",
"    target1val : rows with indf[targetname] == target1val are labelled 1, all others 0.\n",
"    standardize : if True, z-score each feature using the mean and sample std of the\n",
"        training split only, so no test-set statistics leak into the model.\n",
"    train_size : fraction of rows used for training (passed to train_test_split).\n",
"    random_state : seed for train_test_split; None (the default) keeps the split random.\n",
"\n",
"    Returns\n",
"    -------\n",
"    (clf, Xtrain, ytrain, Xtest, ytest), with the X arrays standardized if requested.\n",
"    \"\"\"\n",
"    import numpy as np\n",
"\n",
"    X = indf[featurenames].values\n",
"    y = (indf[targetname].values == target1val) * 1\n",
"    # Split before standardizing so the scaling statistics come from the training rows only.\n",
"    Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, train_size=train_size, random_state=random_state)\n",
"    if standardize:\n",
"        mean = Xtrain.mean(axis=0)\n",
"        std = Xtrain.std(axis=0, ddof=1)  # ddof=1: sample std, as pandas' DataFrame.std()\n",
"        std = np.where(std == 0, 1.0, std)  # constant columns become 0 instead of NaN\n",
"        Xtrain = (Xtrain - mean) / std\n",
"        Xtest = (Xtest - mean) / std\n",
"    clf = cv_optimize(clf, parameters, Xtrain, ytrain)\n",
"    clf = clf.fit(Xtrain, ytrain)\n",
"    training_accuracy = clf.score(Xtrain, ytrain)\n",
"    test_accuracy = clf.score(Xtest, ytest)\n",
"    print(\"Accuracy on training data: {:0.2f}\".format(training_accuracy))\n",
"    print(\"Accuracy on test data: {:0.2f}\".format(test_accuracy))\n",
"    return clf, Xtrain, ytrain, Xtest, ytest"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import os\\nimport sklearn\\nimport matplotlib.pyplot as plt\\nimport pandas as pd'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"BEST PARAMS {'C': 0.01}\n",
"Accuracy on training data: 0.91\n",
"Accuracy on test data: 0.90\n"
]
}
],
"source": [
"clf_l, Xtrain_l, ytrain_l, Xtest_l, ytest_l = do_classify(\n",
"    LogisticRegression(solver='lbfgs'),\n",
"    {\"C\": [0.01, 0.1, 1, 10, 100]},\n",
"    saps,\n",
"    featurenames=['saps2'],\n",
"    targetname='hdeath',\n",
"    target1val=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Standardize"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {
"hide": true
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"def do_classify(clf, parameters, indf, featurenames, targetname, target1val, standardize=True, train_size=0.8, random_state=None):\n",
"    \"\"\"Tune `clf` with cross-validation, fit it on a training split of `indf`, and report accuracy.\n",
"\n",
"    This definition standardizes features by default (standardize=True).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    clf : estimator passed to cv_optimize together with the `parameters` grid.\n",
"    indf : DataFrame holding the feature columns `featurenames` and the column `targetname`.\n",
"    target1val : rows with indf[targetname] == target1val are labelled 1, all others 0.\n",
"    standardize : if True, z-score each feature using the mean and sample std of the\n",
"        training split only, so no test-set statistics leak into the model.\n",
"    train_size : fraction of rows used for training (passed to train_test_split).\n",
"    random_state : seed for train_test_split; None (the default) keeps the split random.\n",
"\n",
"    Returns\n",
"    -------\n",
"    (clf, Xtrain, ytrain, Xtest, ytest), with the X arrays standardized if requested.\n",
"    \"\"\"\n",
"    import numpy as np\n",
"\n",
"    X = indf[featurenames].values\n",
"    y = (indf[targetname].values == target1val) * 1\n",
"    # Split before standardizing so the scaling statistics come from the training rows only.\n",
"    Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, train_size=train_size, random_state=random_state)\n",
"    if standardize:\n",
"        mean = Xtrain.mean(axis=0)\n",
"        std = Xtrain.std(axis=0, ddof=1)  # ddof=1: sample std, as pandas' DataFrame.std()\n",
"        std = np.where(std == 0, 1.0, std)  # constant columns become 0 instead of NaN\n",
"        Xtrain = (Xtrain - mean) / std\n",
"        Xtest = (Xtest - mean) / std\n",
"    clf = cv_optimize(clf, parameters, Xtrain, ytrain)\n",
"    clf = clf.fit(Xtrain, ytrain)\n",
"    training_accuracy = clf.score(Xtrain, ytrain)\n",
"    test_accuracy = clf.score(Xtest, ytest)\n",
"    print(\"Accuracy on training data: {:0.2f}\".format(training_accuracy))\n",
"    print(\"Accuracy on test data: {:0.2f}\".format(test_accuracy))\n",
"    return clf, Xtrain, ytrain, Xtest, ytest"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport seaborn as sns\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"BEST PARAMS {'C': 0.01}\n",
"Accuracy on training data: 0.91\n",
"Accuracy on test data: 0.91\n"
]
}
],
"source": [
"# Pass standardize=True explicitly so this cell does not depend on which do_classify\n",
"# definition (and therefore which default) is currently active in the kernel.\n",
"clf_l, Xtrain_l, ytrain_l, Xtest_l, ytest_l = do_classify(LogisticRegression(solver='lbfgs'), \n",
"                                                          {\"C\": [0.01, 0.1, 1, 10, 100]}, \n",
"                                                          saps, ['saps2'], 'hdeath', 1,\n",
"                                                          standardize=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ROC Curve\n",
"Plotting an ROC (receiver operating characteristic) curve."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"def make_roc(name, clf, ytest, xtest, ax=None, labe=5, proba=True, skip=0):\n",
"    \"\"\"Plot the ROC curve of a fitted binary classifier on (xtest, ytest) and return the Axes.\n",
"\n",
"    name  : label used in the legend entry 'ROC curve for <name> (area = <AUC>)'.\n",
"    ax    : Axes to draw on. If None, the current Axes is used, and the chance diagonal,\n",
"            axis limits, labels, title and legend are drawn as well.\n",
"    labe  : annotate every `labe`-th ROC point with its rounded decision threshold.\n",
"    proba : score with predict_proba(xtest)[:, 1] if True, else with decision_function(xtest).\n",
"    skip  : if non-zero, plot only every `skip`-th ROC point, with circle markers.\n",
"    \"\"\"\n",
"    # Imported locally so the function does not rely on another cell having imported them.\n",
"    from sklearn.metrics import roc_curve, auc\n",
"    import numpy as np\n",
"    import matplotlib.pyplot as plt\n",
"\n",
"    initial = ax is None\n",
"    if initial:\n",
"        ax = plt.gca()\n",
"    scores = clf.predict_proba(xtest)[:, 1] if proba else clf.decision_function(xtest)\n",
"    fpr, tpr, thresholds = roc_curve(ytest, scores)\n",
"    roc_auc = auc(fpr, tpr)\n",
"    label = 'ROC curve for %s (area = %0.2f)' % (name, roc_auc)\n",
"    if skip:\n",
"        ax.plot(fpr[::skip], tpr[::skip], 'o-', alpha=0.8, label=label)\n",
"    else:\n",
"        ax.plot(fpr, tpr, '.-', alpha=0.8, label=label)\n",
"    label_kwargs = {'bbox': dict(boxstyle='round,pad=0.1', alpha=0.1)}\n",
"    # Threshold annotations adapted from https://gist.github.com/podshumok/c1d1c9394335d86255b8\n",
"    for k in range(0, fpr.shape[0], labe):\n",
"        threshold = str(np.round(thresholds[k], 2))\n",
"        ax.annotate(threshold, (fpr[k], tpr[k]), **label_kwargs)\n",
"    if initial:\n",
"        ax.plot([0, 1], [0, 1], 'k--')\n",
"        ax.set_xlim([0.0, 1.0])\n",
"        ax.set_ylim([0.0, 1.05])\n",
"        ax.set_xlabel('False Positive Rate')\n",
"        ax.set_ylabel('True Positive Rate')\n",
"        ax.set_title('ROC')\n",
"    ax.legend(loc=\"lower right\")\n",
"    return ax"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import os\\nimport sklearn\\nimport matplotlib.pyplot as plt\\nimport pandas as pd'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import os\\nimport sklearn\\nimport matplotlib.pyplot as plt\\nimport pandas as pd'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import roc_curve, auc\n",
"\n",
"# Test-split ROC curve for the model fitted above: annotate every 200th threshold\n",
"# and plot every 2nd point to keep the curve readable.\n",
"plt.figure(figsize=(10, 6))\n",
"ax = make_roc(\"logistic\", clf_l, ytest_l, Xtest_l, labe=200, skip=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Cross Validation Score\n",
"\n",
"We should evaluate the performance of an algorithm rigorously, using resampling approaches (e.g. 100 repetitions of 5-fold cross-validation) to measure the variability of its performance. On a particular hold-out set, two algorithms may perform very similarly while the variability of their estimates differs massively. That has serious implications when we deploy the model or use it to draw conclusions about future performance."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AUC scores computed using 5-fold cross-validation: [0.84631633 0.81585015 0.85512388 0.85891427 0.84046901]\n"
]
}
],
"source": [
"# Import necessary modules\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"# Cross-validated AUC on the training split. The test split (Xtest_l, ytest_l) stays\n",
"# held out for the final evaluation. cross_val_score fits clones, so clf_l is not modified.\n",
"cv_auc = cross_val_score(clf_l, Xtrain_l, ytrain_l.ravel(), cv=5, scoring='roc_auc')\n",
"\n",
"# Print the per-fold AUC scores and their spread across folds\n",
"print(\"AUC scores computed using 5-fold cross-validation: {}\".format(cv_auc))\n",
"print(\"Mean AUC: {:.4f} (std {:.4f})\".format(cv_auc.mean(), cv_auc.std()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
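"Next, fit a multivariable model that uses the individual SAPS II components as features. (TODO: experiment with `class_weight` to see whether reweighting the classes helps.)"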
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['SUBJECT_ID',\n",
" 'HADM_ID',\n",
" 'ICUSTAY_ID',\n",
" 'los',\n",
" 'hdeath',\n",
" 'death',\n",
" 'admission',\n",
" 'ud',\n",
" 'bun',\n",
" 'Bicarbonate',\n",
" 'ventilation',\n",
" 'Temp',\n",
" 'Bilirubin',\n",
" 'gcs',\n",
" 'AGE',\n",
" 'UO',\n",
" 'saps2',\n",
" 'Potassium_0.0',\n",
" 'Potassium_3.0',\n",
" 'Sodium_0.0',\n",
" 'Sodium_1.0',\n",
" 'Sodium_5.0',\n",
" 'WBC_0.0',\n",
" 'WBC_3.0',\n",
" 'hr_0.0',\n",
" 'hr_2.0',\n",
" 'hr_4.0',\n",
" 'hr_7.0',\n",
" 'hr_11.0',\n",
" 'bp_0.0',\n",
" 'bp_2.0',\n",
" 'bp_5.0',\n",
" 'bp_13.0']"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Column names after one-hot encoding, for reference when selecting features.\n",
"saps.columns.tolist()"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dimensions of y before reshaping: (61117,)\n",
"Dimensions of X before reshaping: (61117, 26)\n",
"Dimensions of y after reshaping: (61117, 1)\n",
"Dimensions of X after reshaping: (61117, 26)\n"
]
}
],
"source": [
"# Multivariable model: the individual SAPS II component scores (and their one-hot\n",
"# encodings) are the features. IDs, hdeath/death/los and the saps2 total are not included.\n",
"feature_cols = ['admission', 'ud', 'bun', 'Bicarbonate', 'ventilation', 'Temp', 'Bilirubin',\n",
"                'gcs', 'AGE', 'UO',\n",
"                'Potassium_0.0', 'Potassium_3.0',\n",
"                'Sodium_0.0', 'Sodium_1.0', 'Sodium_5.0',\n",
"                'WBC_0.0', 'WBC_3.0',\n",
"                'hr_0.0', 'hr_2.0', 'hr_4.0', 'hr_7.0', 'hr_11.0',\n",
"                'bp_0.0', 'bp_2.0', 'bp_5.0', 'bp_13.0']\n",
"y = saps['hdeath'].values\n",
"X = saps[feature_cols].values\n",
"print(\"Dimensions of y before reshaping: {}\".format(y.shape))\n",
"print(\"Dimensions of X before reshaping: {}\".format(X.shape))\n",
"\n",
"# X already has one column per feature; only y is turned into an (n, 1) column vector.\n",
"y = y.reshape(-1, 1)\n",
"print(\"Dimensions of y after reshaping: {}\".format(y.shape))\n",
"print(\"Dimensions of X after reshaping: {}\".format(X.shape))"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Accuracy: 0.9216135436437812\n",
"Testing Accuracy: 0.918913612565445\n"
]
}
],
"source": [
"# Split into training and test sets with the same seed as the single-feature model.\n",
"Xlr, Xtestlr, ylr, ytestlr = train_test_split(X, y.ravel(), random_state=5)\n",
"\n",
"# max_iter=1000 (scikit-learn's default is 100) gives lbfgs more iterations on the 26 raw features.\n",
"clf = LogisticRegression(solver='lbfgs', max_iter=1000).fit(Xlr, ylr)\n",
"\n",
"print('Training Accuracy: {}'.format(accuracy_score(ylr, clf.predict(Xlr))))\n",
"print('Testing Accuracy: {}'.format(accuracy_score(ytestlr, clf.predict(Xtestlr))))\n"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"def cv_score(clf, x, y, score_func=accuracy_score, nfold=5):\n",
"    \"\"\"Return the mean of `score_func` over `nfold` K-fold splits of (x, y).\n",
"\n",
"    `clf` is refit on every training fold, so it is left fitted on the last fold.\n",
"    `score_func` is called as score_func(y_true, y_pred), the scikit-learn metric\n",
"    convention, so asymmetric metrics such as precision or recall are computed correctly.\n",
"    KFold is used without shuffling, so the folds are contiguous blocks of rows.\n",
"    \"\"\"\n",
"    result = 0\n",
"    for train, test in KFold(nfold).split(x):\n",
"        clf.fit(x[train], y[train])\n",
"        result += score_func(y[test], clf.predict(x[test]))  # (y_true, y_pred)\n",
"    return result / nfold"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1],\n",
" [0],\n",
" [0],\n",
" ...,\n",
" [0],\n",
" [0],\n",
" [0]], dtype=int64)"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Int64Index: 61117 entries, 0 to 61116\n",
"Data columns (total 33 columns):\n",
"SUBJECT_ID 61117 non-null int64\n",
"HADM_ID 61117 non-null int64\n",
"ICUSTAY_ID 61117 non-null int64\n",
"los 61117 non-null float64\n",
"hdeath 61117 non-null int64\n",
"death 61117 non-null int64\n",
"admission 61117 non-null int64\n",
"ud 61117 non-null float64\n",
"bun 61117 non-null float64\n",
"Bicarbonate 61117 non-null float64\n",
"ventilation 61117 non-null float64\n",
"Temp 61117 non-null float64\n",
"Bilirubin 61117 non-null float64\n",
"gcs 61117 non-null float64\n",
"AGE 61117 non-null float64\n",
"UO 61117 non-null float64\n",
"saps2 61117 non-null float64\n",
"Potassium_0.0 61117 non-null int64\n",
"Potassium_3.0 61117 non-null int64\n",
"Sodium_0.0 61117 non-null int64\n",
"Sodium_1.0 61117 non-null int64\n",
"Sodium_5.0 61117 non-null int64\n",
"WBC_0.0 61117 non-null int64\n",
"WBC_3.0 61117 non-null int64\n",
"hr_0.0 61117 non-null int64\n",
"hr_2.0 61117 non-null int64\n",
"hr_4.0 61117 non-null int64\n",
"hr_7.0 61117 non-null int64\n",
"hr_11.0 61117 non-null int64\n",
"bp_0.0 61117 non-null int64\n",
"bp_2.0 61117 non-null int64\n",
"bp_5.0 61117 non-null int64\n",
"bp_13.0 61117 non-null int64\n",
"dtypes: float64(11), int64(22)\n",
"memory usage: 15.9 MB\n"
]
}
],
"source": [
"saps.info()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"saps['Potassium_0.0'] = saps['Potassium_0.0'].astype(np.int64)\n",
"saps['Potassium_3.0'] = saps['Potassium_3.0'].astype(np.int64)\n",
"saps['Sodium_0.0'] = saps['Sodium_0.0'].astype(np.int64)\n",
"saps['Sodium_1.0'] = saps['Sodium_1.0'].astype(np.int64)\n",
"saps['Sodium_5.0'] = saps['Sodium_5.0'].astype(np.int64)\n",
"saps['WBC_0.0'] = saps['WBC_0.0'].astype(np.int64)\n",
"saps['WBC_3.0'] = saps['WBC_3.0'].astype(np.int64)\n",
"saps['hr_0.0'] = saps['hr_0.0'].astype(np.int64)\n",
"saps['hr_2.0'] = saps['hr_2.0'].astype(np.int64)\n",
"saps['hr_4.0'] = saps['hr_4.0'].astype(np.int64)\n",
"saps['hr_7.0'] = saps['hr_7.0'].astype(np.int64)\n",
"saps['hr_11.0'] = saps['hr_11.0'].astype(np.int64)\n",
"saps['bp_0.0'] = saps['bp_0.0'].astype(np.int64)\n",
"saps['bp_2.0'] = saps['bp_2.0'].astype(np.int64)\n",
"saps['bp_5.0'] = saps['bp_5.0'].astype(np.int64)\n",
"saps['bp_13.0'] = saps['bp_13.0'].astype(np.int64)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9214825696009912\n"
]
}
],
"source": [
"score = cv_score(clf, Xlr, ylr)\n",
"print(score)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"score: 0.921002639409019, C:0.01\n",
"score: 0.9214171293342783, C:0.1\n",
"score: 0.9214825696009912, C:1\n",
"score: 0.9215043846097171, C:10\n",
"score: 0.9215262019981758, C:100\n",
"\n",
"The Maximum score with training data is 0.9215262019981758 for a C value of 100.\n"
]
}
],
"source": [
"#the grid of parameters to search over\n",
"Cs = [0.01, 0.1, 1, 10, 100]\n",
"max_score = 0\n",
"for c in Cs:\n",
" clf=LogisticRegression(solver='lbfgs', max_iter=1000, C=c)\n",
" score = cv_score(clf, Xlr, ylr)\n",
" print(f'score: {score}, C:{c}')\n",
" if score > max_score:\n",
" max_score = score\n",
" max_C = c\n",
"print(f'\\nThe Maximum score with training data is {max_score} for a C value of {max_C}.')"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The accuracy with the test data is 0.919044502617801.\n"
]
}
],
"source": [
"clf =LogisticRegression(solver='lbfgs', max_iter=1000, C=max_C)\n",
"# Fit the model on teh training data\n",
"clf.fit(Xlr, ylr)\n",
"# Print the accuracy from the test data\n",
"print(f'The accuracy with the test data is {accuracy_score(clf.predict(Xtestlr), ytestlr)}.')"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best score on training data: 0.9215480943342714 using {'C': 0.1, 'penalty': 'l1', 'solver': 'liblinear'}\n"
]
}
],
"source": [
"model = LogisticRegression(max_iter=1000)\n",
"\n",
"# define parameter values\n",
"solvers = ['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga']\n",
"penalty = ['none', 'l1', 'l2', 'elasticnet']\n",
"c_values = [100, 10, 1.0, 0.1, 0.01]\n",
"\n",
"# define grid search\n",
"grid = dict(solver=solvers,penalty=penalty,C=c_values)\n",
"grid_search = GridSearchCV(estimator=model, param_grid=grid, n_jobs=-1, cv=5, scoring='accuracy', error_score=0)\n",
"grid_result = grid_search.fit(Xlr, ylr)\n",
"\n",
"# summarize results\n",
"print(f\"Best score on training data: {grid_result.best_score_} using {grid_result.best_params_}\")"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Score on test data: 0.918782722513089\n"
]
}
],
"source": [
"print(f'Score on test data: {accuracy_score(grid_result.predict(Xtestlr), ytestlr)}')"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"def cv_optimize(clf, parameters, Xtrain, ytrain, n_folds=5):\n",
" gs = sklearn.model_selection.GridSearchCV(clf, param_grid=parameters, cv=n_folds)\n",
" gs.fit(Xtrain, ytrain)\n",
" print(\"BEST PARAMS\", gs.best_params_)\n",
" best = gs.best_estimator_\n",
" return best"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"def do_classify(clf, parameters, indf, featurenames, targetname, target1val, standardize=False, train_size=0.8):\n",
" subdf=indf[featurenames]\n",
" if standardize:\n",
" subdfstd=(subdf - subdf.mean())/subdf.std()\n",
" else:\n",
" subdfstd=subdf\n",
" X=subdfstd.values\n",
" y=(indf[targetname].values==target1val)*1\n",
" Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, train_size=train_size)\n",
" clf = cv_optimize(clf, parameters, Xtrain, ytrain)\n",
" clf=clf.fit(Xtrain, ytrain)\n",
" training_accuracy = clf.score(Xtrain, ytrain)\n",
" test_accuracy = clf.score(Xtest, ytest)\n",
" print(\"Accuracy on training data: {:0.2f}\".format(training_accuracy))\n",
" print(\"Accuracy on test data: {:0.2f}\".format(test_accuracy))\n",
" return clf, Xtrain, ytrain, Xtest, ytest"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"BEST PARAMS {'C': 10}\n",
"Accuracy on training data: 0.92\n",
"Accuracy on test data: 0.92\n"
]
}
],
"source": [
"clf_l, Xtrain_l, ytrain_l, Xtest_l, ytest_l = do_classify(LogisticRegression(solver='lbfgs', max_iter=2000), \n",
" {\"C\": [0.01, 0.1, 1, 10, 100]}, \n",
" saps, ['admission',\n",
" 'ud',\n",
" 'bun',\n",
" 'Bicarbonate',\n",
" 'ventilation',\n",
" 'Temp',\n",
" 'Bilirubin',\n",
" 'gcs',\n",
" 'AGE',\n",
" 'UO',\n",
" 'Potassium_0.0',\n",
" 'Potassium_3.0',\n",
" 'Sodium_0.0',\n",
" 'Sodium_1.0',\n",
" 'Sodium_5.0',\n",
" 'WBC_0.0',\n",
" 'WBC_3.0',\n",
" 'hr_0.0',\n",
" 'hr_2.0',\n",
" 'hr_4.0',\n",
" 'hr_7.0',\n",
" 'hr_11.0',\n",
" 'bp_0.0',\n",
" 'bp_2.0',\n",
" 'bp_5.0',\n",
" 'bp_13.0'], 'hdeath',1)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"def make_roc(name, clf, ytest, xtest, ax=None, labe=5, proba=True, skip=0):\n",
" initial=False\n",
" if not ax:\n",
" ax=plt.gca()\n",
" initial=True\n",
" if proba:\n",
" fpr, tpr, thresholds=roc_curve(ytest, clf.predict_proba(xtest)[:,1])\n",
" else:\n",
" fpr, tpr, thresholds=roc_curve(ytest, clf.decision_function(xtest))\n",
" roc_auc = auc(fpr, tpr)\n",
" if skip:\n",
" l=fpr.shape[0]\n",
" ax.plot(fpr[0:l:skip], tpr[0:l:skip], 'o-', alpha=0.8, label='ROC curve for %s (area = %0.2f)' % (name, roc_auc))\n",
" else:\n",
" ax.plot(fpr, tpr, '.-', alpha=0.8, label='ROC curve for %s (area = %0.2f)' % (name, roc_auc))\n",
" label_kwargs = {}\n",
" label_kwargs['bbox'] = dict(\n",
" boxstyle='round,pad=0.1', alpha=0.1,\n",
" )\n",
" for k in range(0, fpr.shape[0],labe):\n",
" #from https://gist.github.com/podshumok/c1d1c9394335d86255b8\n",
" threshold = str(np.round(thresholds[k], 2))\n",
" ax.annotate(threshold, (fpr[k], tpr[k]), **label_kwargs)\n",
" if initial:\n",
" ax.plot([0, 1], [0, 1], 'k--')\n",
" ax.set_xlim([0.0, 1.0])\n",
" ax.set_ylim([0.0, 1.05])\n",
" ax.set_xlabel('False Positive Rate')\n",
" ax.set_ylabel('True Positive Rate')\n",
" ax.set_title('ROC')\n",
" ax.legend(loc=\"lower right\")\n",
" return ax"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10,6))\n",
"ax=make_roc(\"logistic\", clf_l, ytest_l, Xtest_l, labe=200, skip=2)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AUC scores computed using 5-fold cross-validation: [0.86265553 0.88734978 0.86138377 0.87400611 0.85694651]\n"
]
}
],
"source": [
"# Compute cross-validated AUC scores: cv_auc\n",
"cv_auc = cross_val_score(clf_l, Xtest_l, ytest_l.ravel(), cv=5, scoring='roc_auc')\n",
"\n",
"# Print list of AUC scores\n",
"print(\"AUC scores computed using 5-fold cross-validation: {}\".format(cv_auc))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#standard scaling \n",
"# xgboost\n",
"#\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Weighted Logistic Regression for Imbalanced Dataset"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
" if (window._pyforest_update_imports_cell) { window._pyforest_update_imports_cell('import sklearn\\nimport os\\nimport pandas as pd\\nimport numpy as np\\nimport matplotlib.pyplot as plt'); }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"survived=saps.saps2.loc[saps.hdeath==0]\n",
"deceased=saps.saps2.loc[saps.hdeath==1]\n",
"\n",
"_ = plt.hist(survived, bins=30, alpha=0.5, label='ICU Patients Survived')\n",
"_ = plt.hist(deceased, bins=30, alpha=0.5, label='ICU Patients Deceased')\n",
"_ = plt.xlabel('SAPSII Total Score')\n",
"_ = plt.ylabel('Frequency')\n",
"_ = plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[13384 199]\n",
" [ 1038 659]]\n",
" precision recall f1-score support\n",
"\n",
" 0 0.93 0.99 0.96 13583\n",
" 1 0.77 0.39 0.52 1697\n",
"\n",
" accuracy 0.92 15280\n",
" macro avg 0.85 0.69 0.74 15280\n",
"weighted avg 0.91 0.92 0.91 15280\n",
"\n"
]
}
],
"source": [
"# Generate the confusion matrix and classification report\n",
"# Import necessary modules\n",
"from sklearn.metrics import classification_report\n",
"from sklearn.metrics import confusion_matrix\n",
"ypred = clf.predict(Xtestlr)\n",
"# Generate the confusion matrix and classification report\n",
"print(confusion_matrix(ytestlr, ypred))\n",
"print(classification_report(ytestlr, ypred))"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy Score: 0.9013089005235602\n",
"[[13567 16]\n",
" [ 1492 205]]\n",
" precision recall f1-score support\n",
"\n",
" 0 0.90 1.00 0.95 13583\n",
" 1 0.93 0.12 0.21 1697\n",
"\n",
" accuracy 0.90 15280\n",
" macro avg 0.91 0.56 0.58 15280\n",
"weighted avg 0.90 0.90 0.87 15280\n",
"\n",
"Area Under Curve: 0.5598117356217265\n",
"Recall score: 0.12080141426045964\n"
]
}
],
"source": [
"from sklearn.metrics import roc_auc_score, recall_score\n",
"\n",
"# define class weights (11%, 89%)\n",
"w = {0:89, 1:11}\n",
"\n",
"# define model\n",
"clf2 = LogisticRegression(solver='lbfgs', max_iter=1000, class_weight=w)\n",
"# fit\n",
"clf2.fit(Xlr,ylr)\n",
"# test\n",
"ypred = clf2.predict(Xtestlr)\n",
"# performance\n",
"print(f'Accuracy Score: {accuracy_score(ytestlr,ypred)}')\n",
"print(confusion_matrix(ytestlr, ypred))\n",
"print(classification_report(ytestlr, ypred))\n",
"print(f'Area Under Curve: {roc_auc_score(ytestlr, ypred)}')\n",
"print(f'Recall score: {recall_score(ytestlr,ypred)}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### XGBoost \n",
"\n",
"XGBoost Python api provides a method to assess the incremental performance by the incremental number of trees. It uses two arguments: “eval_set” — usually Train and Test sets — and the associated “eval_metric” to measure your error on these evaluation sets."
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
"from xgboost import XGBClassifier\n",
"model = XGBClassifier(silent=False, \n",
" scale_pos_weight=1,\n",
" learning_rate=0.01, \n",
" colsample_bytree = 0.4,\n",
" subsample = 0.8,\n",
" objective='binary:logistic', \n",
" n_estimators=100, \n",
" reg_alpha = 0.3,\n",
" max_depth=3, \n",
" gamma=1)"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0]\tvalidation_0-auc:0.805756\tvalidation_0-error:0.083622\tvalidation_1-auc:0.809636\tvalidation_1-error:0.087631\n",
"[1]\tvalidation_0-auc:0.806417\tvalidation_0-error:0.083535\tvalidation_1-auc:0.806968\tvalidation_1-error:0.087696\n",
"[2]\tvalidation_0-auc:0.833609\tvalidation_0-error:0.083535\tvalidation_1-auc:0.83371\tvalidation_1-error:0.087565\n",
"[3]\tvalidation_0-auc:0.847123\tvalidation_0-error:0.082968\tvalidation_1-auc:0.848982\tvalidation_1-error:0.086976\n",
"[4]\tvalidation_0-auc:0.841802\tvalidation_0-error:0.082968\tvalidation_1-auc:0.842443\tvalidation_1-error:0.086322\n",
"[5]\tvalidation_0-auc:0.843148\tvalidation_0-error:0.082902\tvalidation_1-auc:0.844063\tvalidation_1-error:0.086846\n",
"[6]\tvalidation_0-auc:0.85287\tvalidation_0-error:0.083535\tvalidation_1-auc:0.854493\tvalidation_1-error:0.087696\n",
"[7]\tvalidation_0-auc:0.852881\tvalidation_0-error:0.083557\tvalidation_1-auc:0.85407\tvalidation_1-error:0.087107\n",
"[8]\tvalidation_0-auc:0.854487\tvalidation_0-error:0.08347\tvalidation_1-auc:0.854802\tvalidation_1-error:0.087631\n",
"[9]\tvalidation_0-auc:0.855782\tvalidation_0-error:0.083208\tvalidation_1-auc:0.855388\tvalidation_1-error:0.087173\n",
"[10]\tvalidation_0-auc:0.859383\tvalidation_0-error:0.083797\tvalidation_1-auc:0.858915\tvalidation_1-error:0.087696\n",
"[11]\tvalidation_0-auc:0.858277\tvalidation_0-error:0.083841\tvalidation_1-auc:0.857677\tvalidation_1-error:0.087827\n",
"[12]\tvalidation_0-auc:0.85842\tvalidation_0-error:0.084604\tvalidation_1-auc:0.858429\tvalidation_1-error:0.088416\n",
"[13]\tvalidation_0-auc:0.857983\tvalidation_0-error:0.086219\tvalidation_1-auc:0.858302\tvalidation_1-error:0.090118\n",
"[14]\tvalidation_0-auc:0.859352\tvalidation_0-error:0.085215\tvalidation_1-auc:0.860531\tvalidation_1-error:0.089332\n",
"[15]\tvalidation_0-auc:0.860412\tvalidation_0-error:0.083797\tvalidation_1-auc:0.861723\tvalidation_1-error:0.087893\n",
"[16]\tvalidation_0-auc:0.862379\tvalidation_0-error:0.083971\tvalidation_1-auc:0.862877\tvalidation_1-error:0.087893\n",
"[17]\tvalidation_0-auc:0.862198\tvalidation_0-error:0.083688\tvalidation_1-auc:0.862819\tvalidation_1-error:0.087696\n",
"[18]\tvalidation_0-auc:0.86192\tvalidation_0-error:0.083601\tvalidation_1-auc:0.862644\tvalidation_1-error:0.087631\n",
"[19]\tvalidation_0-auc:0.861278\tvalidation_0-error:0.083666\tvalidation_1-auc:0.861946\tvalidation_1-error:0.087696\n",
"[20]\tvalidation_0-auc:0.862406\tvalidation_0-error:0.083688\tvalidation_1-auc:0.862919\tvalidation_1-error:0.087696\n",
"[21]\tvalidation_0-auc:0.865696\tvalidation_0-error:0.084342\tvalidation_1-auc:0.866523\tvalidation_1-error:0.089005\n",
"[22]\tvalidation_0-auc:0.865447\tvalidation_0-error:0.085019\tvalidation_1-auc:0.866014\tvalidation_1-error:0.089202\n",
"[23]\tvalidation_0-auc:0.86629\tvalidation_0-error:0.084713\tvalidation_1-auc:0.866522\tvalidation_1-error:0.089005\n",
"[24]\tvalidation_0-auc:0.86573\tvalidation_0-error:0.084321\tvalidation_1-auc:0.865647\tvalidation_1-error:0.088743\n",
"[25]\tvalidation_0-auc:0.866428\tvalidation_0-error:0.084691\tvalidation_1-auc:0.86653\tvalidation_1-error:0.089005\n",
"[26]\tvalidation_0-auc:0.866293\tvalidation_0-error:0.084495\tvalidation_1-auc:0.866744\tvalidation_1-error:0.088809\n",
"[27]\tvalidation_0-auc:0.86664\tvalidation_0-error:0.084626\tvalidation_1-auc:0.866838\tvalidation_1-error:0.089071\n",
"[28]\tvalidation_0-auc:0.866418\tvalidation_0-error:0.08443\tvalidation_1-auc:0.866702\tvalidation_1-error:0.088809\n",
"[29]\tvalidation_0-auc:0.865961\tvalidation_0-error:0.084648\tvalidation_1-auc:0.866056\tvalidation_1-error:0.089136\n",
"[30]\tvalidation_0-auc:0.867179\tvalidation_0-error:0.084408\tvalidation_1-auc:0.867253\tvalidation_1-error:0.088678\n",
"[31]\tvalidation_0-auc:0.867052\tvalidation_0-error:0.084626\tvalidation_1-auc:0.867089\tvalidation_1-error:0.088874\n",
"[32]\tvalidation_0-auc:0.867114\tvalidation_0-error:0.084757\tvalidation_1-auc:0.867105\tvalidation_1-error:0.089005\n",
"[33]\tvalidation_0-auc:0.867554\tvalidation_0-error:0.084866\tvalidation_1-auc:0.867242\tvalidation_1-error:0.089071\n",
"[34]\tvalidation_0-auc:0.867817\tvalidation_0-error:0.085084\tvalidation_1-auc:0.86753\tvalidation_1-error:0.089136\n",
"[35]\tvalidation_0-auc:0.86776\tvalidation_0-error:0.085237\tvalidation_1-auc:0.867284\tvalidation_1-error:0.089136\n",
"[36]\tvalidation_0-auc:0.867917\tvalidation_0-error:0.08539\tvalidation_1-auc:0.867914\tvalidation_1-error:0.089136\n",
"[37]\tvalidation_0-auc:0.868079\tvalidation_0-error:0.085324\tvalidation_1-auc:0.868238\tvalidation_1-error:0.089071\n",
"[38]\tvalidation_0-auc:0.867856\tvalidation_0-error:0.085608\tvalidation_1-auc:0.86799\tvalidation_1-error:0.089529\n",
"[39]\tvalidation_0-auc:0.867798\tvalidation_0-error:0.085651\tvalidation_1-auc:0.86819\tvalidation_1-error:0.089529\n",
"[40]\tvalidation_0-auc:0.868403\tvalidation_0-error:0.085564\tvalidation_1-auc:0.868699\tvalidation_1-error:0.089463\n",
"[41]\tvalidation_0-auc:0.86849\tvalidation_0-error:0.085215\tvalidation_1-auc:0.868739\tvalidation_1-error:0.089136\n",
"[42]\tvalidation_0-auc:0.868427\tvalidation_0-error:0.085542\tvalidation_1-auc:0.868531\tvalidation_1-error:0.089529\n",
"[43]\tvalidation_0-auc:0.86876\tvalidation_0-error:0.085542\tvalidation_1-auc:0.86891\tvalidation_1-error:0.089463\n",
"[44]\tvalidation_0-auc:0.869809\tvalidation_0-error:0.085608\tvalidation_1-auc:0.870056\tvalidation_1-error:0.089463\n",
"[45]\tvalidation_0-auc:0.869507\tvalidation_0-error:0.085499\tvalidation_1-auc:0.869661\tvalidation_1-error:0.089463\n",
"[46]\tvalidation_0-auc:0.869391\tvalidation_0-error:0.085608\tvalidation_1-auc:0.869682\tvalidation_1-error:0.08966\n",
"[47]\tvalidation_0-auc:0.869438\tvalidation_0-error:0.08576\tvalidation_1-auc:0.869752\tvalidation_1-error:0.089594\n",
"[48]\tvalidation_0-auc:0.869761\tvalidation_0-error:0.085848\tvalidation_1-auc:0.870027\tvalidation_1-error:0.089725\n",
"[49]\tvalidation_0-auc:0.870033\tvalidation_0-error:0.08552\tvalidation_1-auc:0.870413\tvalidation_1-error:0.089463\n",
"[50]\tvalidation_0-auc:0.870215\tvalidation_0-error:0.085651\tvalidation_1-auc:0.870418\tvalidation_1-error:0.089529\n",
"[51]\tvalidation_0-auc:0.869963\tvalidation_0-error:0.085433\tvalidation_1-auc:0.870094\tvalidation_1-error:0.089332\n",
"[52]\tvalidation_0-auc:0.870333\tvalidation_0-error:0.085084\tvalidation_1-auc:0.870355\tvalidation_1-error:0.08894\n",
"[53]\tvalidation_0-auc:0.870283\tvalidation_0-error:0.085128\tvalidation_1-auc:0.870476\tvalidation_1-error:0.08894\n",
"[54]\tvalidation_0-auc:0.870032\tvalidation_0-error:0.085433\tvalidation_1-auc:0.870275\tvalidation_1-error:0.089398\n",
"[55]\tvalidation_0-auc:0.870189\tvalidation_0-error:0.085455\tvalidation_1-auc:0.870531\tvalidation_1-error:0.089463\n",
"[56]\tvalidation_0-auc:0.869892\tvalidation_0-error:0.085608\tvalidation_1-auc:0.870177\tvalidation_1-error:0.089529\n",
"[57]\tvalidation_0-auc:0.869509\tvalidation_0-error:0.085739\tvalidation_1-auc:0.86973\tvalidation_1-error:0.089594\n",
"[58]\tvalidation_0-auc:0.869183\tvalidation_0-error:0.085979\tvalidation_1-auc:0.869321\tvalidation_1-error:0.089529\n",
"[59]\tvalidation_0-auc:0.869056\tvalidation_0-error:0.085564\tvalidation_1-auc:0.869175\tvalidation_1-error:0.089398\n",
"[60]\tvalidation_0-auc:0.86907\tvalidation_0-error:0.085564\tvalidation_1-auc:0.869009\tvalidation_1-error:0.089398\n",
"[61]\tvalidation_0-auc:0.868867\tvalidation_0-error:0.08576\tvalidation_1-auc:0.868804\tvalidation_1-error:0.089529\n",
"[62]\tvalidation_0-auc:0.868679\tvalidation_0-error:0.085957\tvalidation_1-auc:0.86868\tvalidation_1-error:0.089594\n",
"[63]\tvalidation_0-auc:0.868765\tvalidation_0-error:0.086371\tvalidation_1-auc:0.868526\tvalidation_1-error:0.089856\n",
"[64]\tvalidation_0-auc:0.868661\tvalidation_0-error:0.085717\tvalidation_1-auc:0.868409\tvalidation_1-error:0.089529\n",
"[65]\tvalidation_0-auc:0.86865\tvalidation_0-error:0.085957\tvalidation_1-auc:0.868427\tvalidation_1-error:0.089529\n",
"[66]\tvalidation_0-auc:0.868777\tvalidation_0-error:0.08552\tvalidation_1-auc:0.868408\tvalidation_1-error:0.089398\n",
"[67]\tvalidation_0-auc:0.86935\tvalidation_0-error:0.085717\tvalidation_1-auc:0.868996\tvalidation_1-error:0.089529\n",
"[68]\tvalidation_0-auc:0.869549\tvalidation_0-error:0.086262\tvalidation_1-auc:0.868994\tvalidation_1-error:0.089921\n",
"[69]\tvalidation_0-auc:0.870164\tvalidation_0-error:0.086502\tvalidation_1-auc:0.86985\tvalidation_1-error:0.089921\n",
"[70]\tvalidation_0-auc:0.870224\tvalidation_0-error:0.085957\tvalidation_1-auc:0.869928\tvalidation_1-error:0.089594\n",
"[71]\tvalidation_0-auc:0.869967\tvalidation_0-error:0.085957\tvalidation_1-auc:0.869716\tvalidation_1-error:0.089594\n",
"[72]\tvalidation_0-auc:0.869862\tvalidation_0-error:0.086349\tvalidation_1-auc:0.869547\tvalidation_1-error:0.089921\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[73]\tvalidation_0-auc:0.870112\tvalidation_0-error:0.086437\tvalidation_1-auc:0.869796\tvalidation_1-error:0.089921\n",
"[74]\tvalidation_0-auc:0.870304\tvalidation_0-error:0.086524\tvalidation_1-auc:0.869977\tvalidation_1-error:0.089856\n",
"[75]\tvalidation_0-auc:0.870106\tvalidation_0-error:0.086349\tvalidation_1-auc:0.869769\tvalidation_1-error:0.089921\n",
"[76]\tvalidation_0-auc:0.870013\tvalidation_0-error:0.085979\tvalidation_1-auc:0.869659\tvalidation_1-error:0.089529\n",
"[77]\tvalidation_0-auc:0.869817\tvalidation_0-error:0.086306\tvalidation_1-auc:0.869416\tvalidation_1-error:0.089921\n",
"[78]\tvalidation_0-auc:0.869962\tvalidation_0-error:0.085695\tvalidation_1-auc:0.869603\tvalidation_1-error:0.089398\n",
"[79]\tvalidation_0-auc:0.870129\tvalidation_0-error:0.085455\tvalidation_1-auc:0.869848\tvalidation_1-error:0.089202\n",
"[80]\tvalidation_0-auc:0.869945\tvalidation_0-error:0.085477\tvalidation_1-auc:0.869671\tvalidation_1-error:0.089267\n",
"[81]\tvalidation_0-auc:0.870049\tvalidation_0-error:0.085586\tvalidation_1-auc:0.869803\tvalidation_1-error:0.089398\n",
"[82]\tvalidation_0-auc:0.870052\tvalidation_0-error:0.085717\tvalidation_1-auc:0.869887\tvalidation_1-error:0.089463\n",
"[83]\tvalidation_0-auc:0.87035\tvalidation_0-error:0.08576\tvalidation_1-auc:0.870185\tvalidation_1-error:0.089529\n",
"[84]\tvalidation_0-auc:0.870545\tvalidation_0-error:0.086066\tvalidation_1-auc:0.870365\tvalidation_1-error:0.089594\n",
"[85]\tvalidation_0-auc:0.870496\tvalidation_0-error:0.085673\tvalidation_1-auc:0.870276\tvalidation_1-error:0.089398\n",
"[86]\tvalidation_0-auc:0.870563\tvalidation_0-error:0.085542\tvalidation_1-auc:0.870213\tvalidation_1-error:0.089332\n",
"[87]\tvalidation_0-auc:0.870483\tvalidation_0-error:0.085957\tvalidation_1-auc:0.870176\tvalidation_1-error:0.089463\n",
"[88]\tvalidation_0-auc:0.870437\tvalidation_0-error:0.08624\tvalidation_1-auc:0.870178\tvalidation_1-error:0.08966\n",
"[89]\tvalidation_0-auc:0.871024\tvalidation_0-error:0.086371\tvalidation_1-auc:0.870727\tvalidation_1-error:0.089725\n",
"[90]\tvalidation_0-auc:0.871107\tvalidation_0-error:0.085935\tvalidation_1-auc:0.870856\tvalidation_1-error:0.089267\n",
"[91]\tvalidation_0-auc:0.871102\tvalidation_0-error:0.086175\tvalidation_1-auc:0.870947\tvalidation_1-error:0.089529\n",
"[92]\tvalidation_0-auc:0.871265\tvalidation_0-error:0.08552\tvalidation_1-auc:0.871157\tvalidation_1-error:0.089202\n",
"[93]\tvalidation_0-auc:0.871577\tvalidation_0-error:0.085848\tvalidation_1-auc:0.871507\tvalidation_1-error:0.089267\n",
"[94]\tvalidation_0-auc:0.871623\tvalidation_0-error:0.085499\tvalidation_1-auc:0.871392\tvalidation_1-error:0.089202\n",
"[95]\tvalidation_0-auc:0.871802\tvalidation_0-error:0.085673\tvalidation_1-auc:0.87163\tvalidation_1-error:0.089202\n",
"[96]\tvalidation_0-auc:0.871816\tvalidation_0-error:0.085324\tvalidation_1-auc:0.871624\tvalidation_1-error:0.089136\n",
"[97]\tvalidation_0-auc:0.871802\tvalidation_0-error:0.085411\tvalidation_1-auc:0.87166\tvalidation_1-error:0.089136\n",
"[98]\tvalidation_0-auc:0.871747\tvalidation_0-error:0.085411\tvalidation_1-auc:0.871759\tvalidation_1-error:0.089136\n",
"[99]\tvalidation_0-auc:0.871832\tvalidation_0-error:0.085433\tvalidation_1-auc:0.871987\tvalidation_1-error:0.089136\n",
"Wall time: 3.24 s\n"
]
},
{
"data": {
"text/plain": [
"XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
" colsample_bynode=1, colsample_bytree=0.4, gamma=1,\n",
" learning_rate=0.01, max_delta_step=0, max_depth=3,\n",
" min_child_weight=1, missing=None, n_estimators=100, n_jobs=1,\n",
" nthread=None, objective='binary:logistic', random_state=0,\n",
" reg_alpha=0.3, reg_lambda=1, scale_pos_weight=1, seed=None,\n",
" silent=False, subsample=0.8, verbosity=1)"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fit the boosted model while tracking AUC and classification error after each round.\n",
"# The eval_set order fixes the log labels: validation_0 = training data, validation_1 = test data.\n",
"# This is monitoring only (no early_stopping_rounds), so the test set does not influence the fit.\n",
"# Do not tune n_estimators from the validation_1 curve -- that would leak the test set;\n",
"# carve a separate validation split out of the training data for that instead.\n",
"eval_set = [(Xlr, ylr), (Xtestlr, ytestlr)]\n",
"eval_metric = [\"auc\", \"error\"]\n",
"# verbose=10 prints every 10th round (plus the final one) instead of every round, keeping the\n",
"# notebook output readable; the full per-round history is available via model.evals_result().\n",
"%time model.fit(Xlr, ylr, eval_metric=eval_metric, eval_set=eval_set, verbose=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 1
}