--- a +++ b/LUNG_CANCER_logistic_regression.ipynb @@ -0,0 +1,2372 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lung_cancer: Logistic regression\n", + "---\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<a id='imporp'></a>\n", + "## Importing Packages\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "## Basic packages\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "\n", + "## Graphing packages\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "plt.style.use('fivethirtyeight')\n", + "\n", + "## Scikit learn and Statsmodel packages\n", + "from sklearn.linear_model import LogisticRegression, LinearRegression\n", + "import statsmodels.api as sm\n", + "from sklearn.metrics import confusion_matrix\n", + "## Operating system dependent functionality\n", + "import os\n", + "import statsmodels.api as st \n", + "#from pandas.stats.api import ols\n", + "## Lines of code needed to make sure graph(s) appear in notebook, and check versions of packages\n", + "%matplotlib inline\n", + "#%load_ext watermark\n", + "#%config InlineBackend.figure_format = 'retina'\n", + "#%watermark -v -d -a 'Delta Analytics' -p scikit-learn,matplotlib,numpy,pandas" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<a id='rds'></a>\n", + "## Reading the dataset\n", + "---\n", + "we are using Lung_Cancer dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " 
<thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>patient_id</th>\n", + " <th>age</th>\n", + " <th>gender</th>\n", + " <th>air_pollution</th>\n", + " <th>alcohol_use</th>\n", + " <th>dust_allergy</th>\n", + " <th>occupational_hazards</th>\n", + " <th>genetic_risk</th>\n", + " <th>chronic_lung_disease</th>\n", + " <th>balanced_diet</th>\n", + " <th>...</th>\n", + " <th>fatigue</th>\n", + " <th>weight_loss</th>\n", + " <th>shortness_of_breath</th>\n", + " <th>wheezing</th>\n", + " <th>swallowing_difficulty</th>\n", + " <th>clubbing_of_finger_nails</th>\n", + " <th>frequent_cold</th>\n", + " <th>dry_cough</th>\n", + " <th>snoring</th>\n", + " <th>level</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>P1</td>\n", + " <td>33</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>4</td>\n", + " <td>5</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>...</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>Low</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>P10</td>\n", + " <td>17</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " <td>5</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>...</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>7</td>\n", + " <td>8</td>\n", + " <td>6</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>7</td>\n", + " <td>2</td>\n", + " <td>Medium</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>P107</td>\n", + " <td>44</td>\n", + " <td>1</td>\n", + " <td>6</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>6</td>\n", + " <td>7</td>\n", + " <td>...</td>\n", + " <td>5</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>7</td>\n", + " <td>8</td>\n", + " 
## Locate the CSV (a bare file name, so it is resolved against the current working directory)
LUNG_CANCER_filepath = os.path.join('cancer_patient.csv')

## Load the raw patient table into a DataFrame
LUNG_CANCER = pd.read_csv(filepath_or_buffer=LUNG_CANCER_filepath)

## Preview the first three rows
LUNG_CANCER.head(n=3)
<th>...</th>\n", + " <th>fatigue</th>\n", + " <th>weight_loss</th>\n", + " <th>shortness_of_breath</th>\n", + " <th>wheezing</th>\n", + " <th>swallowing_difficulty</th>\n", + " <th>clubbing_of_finger_nails</th>\n", + " <th>frequent_cold</th>\n", + " <th>dry_cough</th>\n", + " <th>snoring</th>\n", + " <th>level</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>997</th>\n", + " <td>P997</td>\n", + " <td>25</td>\n", + " <td>2</td>\n", + " <td>4</td>\n", + " <td>5</td>\n", + " <td>6</td>\n", + " <td>5</td>\n", + " <td>5</td>\n", + " <td>4</td>\n", + " <td>6</td>\n", + " <td>...</td>\n", + " <td>8</td>\n", + " <td>7</td>\n", + " <td>9</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>4</td>\n", + " <td>6</td>\n", + " <td>7</td>\n", + " <td>2</td>\n", + " <td>High</td>\n", + " </tr>\n", + " <tr>\n", + " <th>998</th>\n", + " <td>P998</td>\n", + " <td>18</td>\n", + " <td>2</td>\n", + " <td>6</td>\n", + " <td>8</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>6</td>\n", + " <td>7</td>\n", + " <td>...</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>4</td>\n", + " <td>1</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>High</td>\n", + " </tr>\n", + " <tr>\n", + " <th>999</th>\n", + " <td>P999</td>\n", + " <td>47</td>\n", + " <td>1</td>\n", + " <td>6</td>\n", + " <td>5</td>\n", + " <td>6</td>\n", + " <td>5</td>\n", + " <td>5</td>\n", + " <td>4</td>\n", + " <td>6</td>\n", + " <td>...</td>\n", + " <td>8</td>\n", + " <td>7</td>\n", + " <td>9</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>4</td>\n", + " <td>6</td>\n", + " <td>7</td>\n", + " <td>2</td>\n", + " <td>High</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>3 rows × 25 columns</p>\n", + "</div>" + ], + "text/plain": [ + " patient_id age gender air_pollution alcohol_use dust_allergy \\\n", + "997 P997 25 2 4 5 6 \n", + "998 P998 18 2 6 8 7 \n", + "999 P999 47 1 6 5 6 \n", + 
"\n", + " occupational_hazards genetic_risk chronic_lung_disease balanced_diet \\\n", + "997 5 5 4 6 \n", + "998 7 7 6 7 \n", + "999 5 5 4 6 \n", + "\n", + " ... fatigue weight_loss shortness_of_breath wheezing \\\n", + "997 ... 8 7 9 2 \n", + "998 ... 3 2 4 1 \n", + "999 ... 8 7 9 2 \n", + "\n", + " swallowing_difficulty clubbing_of_finger_nails frequent_cold \\\n", + "997 1 4 6 \n", + "998 4 2 4 \n", + "999 1 4 6 \n", + "\n", + " dry_cough snoring level \n", + "997 7 2 High \n", + "998 2 3 High \n", + "999 7 2 High \n", + "\n", + "[3 rows x 25 columns]" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "LUNG_CANCER.tail(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<a id='msvl'></a>\n", + "### Missing Values\n", + "---\n", + "1. we will drop the missing values if there is one" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "patient_id 0\n", + "age 0\n", + "gender 0\n", + "air_pollution 0\n", + "alcohol_use 0\n", + "dust_allergy 0\n", + "occupational_hazards 0\n", + "genetic_risk 0\n", + "chronic_lung_disease 0\n", + "balanced_diet 0\n", + "obesity 0\n", + "smoking 0\n", + "passive_smoker 0\n", + "chest_pain 0\n", + "coughing_of_blood 0\n", + "fatigue 0\n", + "weight_loss 0\n", + "shortness_of_breath 0\n", + "wheezing 0\n", + "swallowing_difficulty 0\n", + "clubbing_of_finger_nails 0\n", + "frequent_cold 0\n", + "dry_cough 0\n", + "snoring 0\n", + "level 0\n", + "dtype: int64" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "LUNG_CANCER.isnull().sum()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "LUNG_CANCER.dropna(inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"patient_id 0\n", + "age 0\n", + "gender 0\n", + "air_pollution 0\n", + "alcohol_use 0\n", + "dust_allergy 0\n", + "occupational_hazards 0\n", + "genetic_risk 0\n", + "chronic_lung_disease 0\n", + "balanced_diet 0\n", + "obesity 0\n", + "smoking 0\n", + "passive_smoker 0\n", + "chest_pain 0\n", + "coughing_of_blood 0\n", + "fatigue 0\n", + "weight_loss 0\n", + "shortness_of_breath 0\n", + "wheezing 0\n", + "swallowing_difficulty 0\n", + "clubbing_of_finger_nails 0\n", + "frequent_cold 0\n", + "dry_cough 0\n", + "snoring 0\n", + "level 0\n", + "dtype: int64" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "LUNG_CANCER.isnull().sum()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<a id='implementation'></a>\n", + "## Implementation of Logistic Regression\n", + "---\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<a id='LEVEL'></a>\n", + "### Level: Low, Medium, High\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>age</th>\n", + " <th>gender</th>\n", + " <th>air_pollution</th>\n", + " <th>alcohol_use</th>\n", + " <th>dust_allergy</th>\n", + " <th>occupational_hazards</th>\n", + " <th>genetic_risk</th>\n", + " <th>chronic_lung_disease</th>\n", + " <th>balanced_diet</th>\n", + " <th>obesity</th>\n", + " <th>...</th>\n", + " <th>coughing_of_blood</th>\n", + " <th>fatigue</th>\n", + " <th>weight_loss</th>\n", + " <th>shortness_of_breath</th>\n", + 
" <th>wheezing</th>\n", + " <th>swallowing_difficulty</th>\n", + " <th>clubbing_of_finger_nails</th>\n", + " <th>frequent_cold</th>\n", + " <th>dry_cough</th>\n", + " <th>snoring</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>count</th>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.0000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>...</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " <td>1000.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>mean</th>\n", + " <td>37.174000</td>\n", + " <td>1.402000</td>\n", + " <td>3.8400</td>\n", + " <td>4.563000</td>\n", + " <td>5.165000</td>\n", + " <td>4.840000</td>\n", + " <td>4.580000</td>\n", + " <td>4.380000</td>\n", + " <td>4.491000</td>\n", + " <td>4.465000</td>\n", + " <td>...</td>\n", + " <td>4.859000</td>\n", + " <td>3.856000</td>\n", + " <td>3.855000</td>\n", + " <td>4.240000</td>\n", + " <td>3.777000</td>\n", + " <td>3.746000</td>\n", + " <td>3.923000</td>\n", + " <td>3.536000</td>\n", + " <td>3.853000</td>\n", + " <td>2.926000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>std</th>\n", + " <td>12.005493</td>\n", + " <td>0.490547</td>\n", + " <td>2.0304</td>\n", + " <td>2.620477</td>\n", + " <td>1.980833</td>\n", + " <td>2.107805</td>\n", + " <td>2.126999</td>\n", + " <td>1.848518</td>\n", + " <td>2.135528</td>\n", + " <td>2.124921</td>\n", + " <td>...</td>\n", + " <td>2.427965</td>\n", + " <td>2.244616</td>\n", + " <td>2.206546</td>\n", + " <td>2.285087</td>\n", + " <td>2.041921</td>\n", + " <td>2.270383</td>\n", + " <td>2.388048</td>\n", + " <td>1.832502</td>\n", + " 
<td>2.039007</td>\n", + " <td>1.474686</td>\n", + " </tr>\n", + " <tr>\n", + " <th>min</th>\n", + " <td>14.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.0000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>...</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>25%</th>\n", + " <td>27.750000</td>\n", + " <td>1.000000</td>\n", + " <td>2.0000</td>\n", + " <td>2.000000</td>\n", + " <td>4.000000</td>\n", + " <td>3.000000</td>\n", + " <td>2.000000</td>\n", + " <td>3.000000</td>\n", + " <td>2.000000</td>\n", + " <td>3.000000</td>\n", + " <td>...</td>\n", + " <td>3.000000</td>\n", + " <td>2.000000</td>\n", + " <td>2.000000</td>\n", + " <td>2.000000</td>\n", + " <td>2.000000</td>\n", + " <td>2.000000</td>\n", + " <td>2.000000</td>\n", + " <td>2.000000</td>\n", + " <td>2.000000</td>\n", + " <td>2.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>50%</th>\n", + " <td>36.000000</td>\n", + " <td>1.000000</td>\n", + " <td>3.0000</td>\n", + " <td>5.000000</td>\n", + " <td>6.000000</td>\n", + " <td>5.000000</td>\n", + " <td>5.000000</td>\n", + " <td>4.000000</td>\n", + " <td>4.000000</td>\n", + " <td>4.000000</td>\n", + " <td>...</td>\n", + " <td>4.000000</td>\n", + " <td>3.000000</td>\n", + " <td>3.000000</td>\n", + " <td>4.000000</td>\n", + " <td>4.000000</td>\n", + " <td>4.000000</td>\n", + " <td>4.000000</td>\n", + " <td>3.000000</td>\n", + " <td>4.000000</td>\n", + " <td>3.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>75%</th>\n", + " <td>45.000000</td>\n", + " <td>2.000000</td>\n", + " <td>6.0000</td>\n", + " <td>7.000000</td>\n", + " <td>7.000000</td>\n", + " 
<td>7.000000</td>\n", + " <td>7.000000</td>\n", + " <td>6.000000</td>\n", + " <td>7.000000</td>\n", + " <td>7.000000</td>\n", + " <td>...</td>\n", + " <td>7.000000</td>\n", + " <td>5.000000</td>\n", + " <td>6.000000</td>\n", + " <td>6.000000</td>\n", + " <td>5.000000</td>\n", + " <td>5.000000</td>\n", + " <td>5.000000</td>\n", + " <td>5.000000</td>\n", + " <td>6.000000</td>\n", + " <td>4.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>max</th>\n", + " <td>73.000000</td>\n", + " <td>2.000000</td>\n", + " <td>8.0000</td>\n", + " <td>8.000000</td>\n", + " <td>8.000000</td>\n", + " <td>8.000000</td>\n", + " <td>7.000000</td>\n", + " <td>7.000000</td>\n", + " <td>7.000000</td>\n", + " <td>7.000000</td>\n", + " <td>...</td>\n", + " <td>9.000000</td>\n", + " <td>9.000000</td>\n", + " <td>8.000000</td>\n", + " <td>9.000000</td>\n", + " <td>8.000000</td>\n", + " <td>8.000000</td>\n", + " <td>9.000000</td>\n", + " <td>7.000000</td>\n", + " <td>7.000000</td>\n", + " <td>7.000000</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>8 rows × 23 columns</p>\n", + "</div>" + ], + "text/plain": [ + " age gender air_pollution alcohol_use dust_allergy \\\n", + "count 1000.000000 1000.000000 1000.0000 1000.000000 1000.000000 \n", + "mean 37.174000 1.402000 3.8400 4.563000 5.165000 \n", + "std 12.005493 0.490547 2.0304 2.620477 1.980833 \n", + "min 14.000000 1.000000 1.0000 1.000000 1.000000 \n", + "25% 27.750000 1.000000 2.0000 2.000000 4.000000 \n", + "50% 36.000000 1.000000 3.0000 5.000000 6.000000 \n", + "75% 45.000000 2.000000 6.0000 7.000000 7.000000 \n", + "max 73.000000 2.000000 8.0000 8.000000 8.000000 \n", + "\n", + " occupational_hazards genetic_risk chronic_lung_disease \\\n", + "count 1000.000000 1000.000000 1000.000000 \n", + "mean 4.840000 4.580000 4.380000 \n", + "std 2.107805 2.126999 1.848518 \n", + "min 1.000000 1.000000 1.000000 \n", + "25% 3.000000 2.000000 3.000000 \n", + "50% 5.000000 5.000000 4.000000 \n", + "75% 7.000000 7.000000 6.000000 \n", + 
"max 8.000000 7.000000 7.000000 \n", + "\n", + " balanced_diet obesity ... coughing_of_blood \\\n", + "count 1000.000000 1000.000000 ... 1000.000000 \n", + "mean 4.491000 4.465000 ... 4.859000 \n", + "std 2.135528 2.124921 ... 2.427965 \n", + "min 1.000000 1.000000 ... 1.000000 \n", + "25% 2.000000 3.000000 ... 3.000000 \n", + "50% 4.000000 4.000000 ... 4.000000 \n", + "75% 7.000000 7.000000 ... 7.000000 \n", + "max 7.000000 7.000000 ... 9.000000 \n", + "\n", + " fatigue weight_loss shortness_of_breath wheezing \\\n", + "count 1000.000000 1000.000000 1000.000000 1000.000000 \n", + "mean 3.856000 3.855000 4.240000 3.777000 \n", + "std 2.244616 2.206546 2.285087 2.041921 \n", + "min 1.000000 1.000000 1.000000 1.000000 \n", + "25% 2.000000 2.000000 2.000000 2.000000 \n", + "50% 3.000000 3.000000 4.000000 4.000000 \n", + "75% 5.000000 6.000000 6.000000 5.000000 \n", + "max 9.000000 8.000000 9.000000 8.000000 \n", + "\n", + " swallowing_difficulty clubbing_of_finger_nails frequent_cold \\\n", + "count 1000.000000 1000.000000 1000.000000 \n", + "mean 3.746000 3.923000 3.536000 \n", + "std 2.270383 2.388048 1.832502 \n", + "min 1.000000 1.000000 1.000000 \n", + "25% 2.000000 2.000000 2.000000 \n", + "50% 4.000000 4.000000 3.000000 \n", + "75% 5.000000 5.000000 5.000000 \n", + "max 8.000000 9.000000 7.000000 \n", + "\n", + " dry_cough snoring \n", + "count 1000.000000 1000.000000 \n", + "mean 3.853000 2.926000 \n", + "std 2.039007 1.474686 \n", + "min 1.000000 1.000000 \n", + "25% 2.000000 2.000000 \n", + "50% 4.000000 3.000000 \n", + "75% 6.000000 4.000000 \n", + "max 7.000000 7.000000 \n", + "\n", + "[8 rows x 23 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "## Describe our dataset\n", + "LUNG_CANCER.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<class 
'pandas.core.frame.DataFrame'>\n", + "Int64Index: 1000 entries, 0 to 999\n", + "Data columns (total 25 columns):\n", + "patient_id 1000 non-null object\n", + "age 1000 non-null int64\n", + "gender 1000 non-null int64\n", + "air_pollution 1000 non-null int64\n", + "alcohol_use 1000 non-null int64\n", + "dust_allergy 1000 non-null int64\n", + "occupational_hazards 1000 non-null int64\n", + "genetic_risk 1000 non-null int64\n", + "chronic_lung_disease 1000 non-null int64\n", + "balanced_diet 1000 non-null int64\n", + "obesity 1000 non-null int64\n", + "smoking 1000 non-null int64\n", + "passive_smoker 1000 non-null int64\n", + "chest_pain 1000 non-null int64\n", + "coughing_of_blood 1000 non-null int64\n", + "fatigue 1000 non-null int64\n", + "weight_loss 1000 non-null int64\n", + "shortness_of_breath 1000 non-null int64\n", + "wheezing 1000 non-null int64\n", + "swallowing_difficulty 1000 non-null int64\n", + "clubbing_of_finger_nails 1000 non-null int64\n", + "frequent_cold 1000 non-null int64\n", + "dry_cough 1000 non-null int64\n", + "snoring 1000 non-null int64\n", + "level 1000 non-null object\n", + "dtypes: int64(23), object(2)\n", + "memory usage: 203.1+ KB\n" + ] + } + ], + "source": [ + "LUNG_CANCER.info()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### here we have a categorical Column in our dataset which is Level" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>patient_id</th>\n", + " <th>age</th>\n", + " <th>gender</th>\n", + " 
<th>air_pollution</th>\n", + " <th>alcohol_use</th>\n", + " <th>dust_allergy</th>\n", + " <th>occupational_hazards</th>\n", + " <th>genetic_risk</th>\n", + " <th>chronic_lung_disease</th>\n", + " <th>balanced_diet</th>\n", + " <th>...</th>\n", + " <th>fatigue</th>\n", + " <th>weight_loss</th>\n", + " <th>shortness_of_breath</th>\n", + " <th>wheezing</th>\n", + " <th>swallowing_difficulty</th>\n", + " <th>clubbing_of_finger_nails</th>\n", + " <th>frequent_cold</th>\n", + " <th>dry_cough</th>\n", + " <th>snoring</th>\n", + " <th>level</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>P1</td>\n", + " <td>33</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>4</td>\n", + " <td>5</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>...</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>P10</td>\n", + " <td>17</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " <td>5</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>...</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>7</td>\n", + " <td>8</td>\n", + " <td>6</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>7</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>P107</td>\n", + " <td>44</td>\n", + " <td>1</td>\n", + " <td>6</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>6</td>\n", + " <td>7</td>\n", + " <td>...</td>\n", + " <td>5</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>7</td>\n", + " <td>8</td>\n", + " <td>2</td>\n", + " <td>4</td>\n", + " <td>5</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>P189</td>\n", 
## Ordinal codes for the categorical target column `level`.
LEVEL_CODES = {"Low": 1, "Medium": 2, "High": 3}


def data_cleaning(data):
    """Impute missing ages and encode the ``level`` column as integers.

    The frame passed in is modified in place and also returned, so both
    ``df = data_cleaning(df)`` and a bare ``data_cleaning(df)`` leave ``df``
    cleaned.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain an ``age`` column and a ``level`` column.

    Returns
    -------
    pandas.DataFrame
        The same object as ``data``, with

        * missing ``age`` values replaced by the median of ``data["age"]``;
        * ``level`` values ``"Low"`` / ``"Medium"`` / ``"High"`` replaced by
          ``1`` / ``2`` / ``3``. Any other value (including codes produced by
          an earlier run of this cell) is left untouched, which makes the
          cell safe to re-run.
    """
    # Use the frame that was passed in, not the notebook-level global: the
    # previous version read and wrote LUNG_CANCER regardless of `data`.
    data["age"] = data["age"].fillna(data["age"].median())

    # One pass instead of three masked assignments. Mapping through a function
    # lets pandas infer an integer dtype when every value was encoded, instead
    # of leaving Python ints inside an object column.
    data["level"] = data["level"].map(lambda value: LEVEL_CODES.get(value, value))

    return data


LUNG_CANCER = data_cleaning(LUNG_CANCER)
LUNG_CANCER.head(4)
th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>patient_id</th>\n", + " <th>age</th>\n", + " <th>gender</th>\n", + " <th>air_pollution</th>\n", + " <th>alcohol_use</th>\n", + " <th>dust_allergy</th>\n", + " <th>occupational_hazards</th>\n", + " <th>genetic_risk</th>\n", + " <th>chronic_lung_disease</th>\n", + " <th>balanced_diet</th>\n", + " <th>...</th>\n", + " <th>weight_loss</th>\n", + " <th>shortness_of_breath</th>\n", + " <th>wheezing</th>\n", + " <th>swallowing_difficulty</th>\n", + " <th>clubbing_of_finger_nails</th>\n", + " <th>frequent_cold</th>\n", + " <th>dry_cough</th>\n", + " <th>snoring</th>\n", + " <th>level</th>\n", + " <th>intercept</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>P1</td>\n", + " <td>33</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>4</td>\n", + " <td>5</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>...</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>1</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>P10</td>\n", + " <td>17</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " <td>5</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>...</td>\n", + " <td>3</td>\n", + " <td>7</td>\n", + " <td>8</td>\n", + " <td>6</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>7</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>P107</td>\n", + " 
<td>44</td>\n", + " <td>1</td>\n", + " <td>6</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>6</td>\n", + " <td>7</td>\n", + " <td>...</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>7</td>\n", + " <td>8</td>\n", + " <td>2</td>\n", + " <td>4</td>\n", + " <td>5</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>P189</td>\n", + " <td>39</td>\n", + " <td>2</td>\n", + " <td>6</td>\n", + " <td>8</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>7</td>\n", + " <td>6</td>\n", + " <td>7</td>\n", + " <td>...</td>\n", + " <td>2</td>\n", + " <td>4</td>\n", + " <td>1</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>4 rows × 26 columns</p>\n", + "</div>" + ], + "text/plain": [ + " patient_id age gender air_pollution alcohol_use dust_allergy \\\n", + "0 P1 33 1 2 4 5 \n", + "1 P10 17 1 3 1 5 \n", + "2 P107 44 1 6 7 7 \n", + "3 P189 39 2 6 8 7 \n", + "\n", + " occupational_hazards genetic_risk chronic_lung_disease balanced_diet \\\n", + "0 4 3 2 2 \n", + "1 3 4 2 2 \n", + "2 7 7 6 7 \n", + "3 7 7 6 7 \n", + "\n", + " ... weight_loss shortness_of_breath wheezing \\\n", + "0 ... 4 2 2 \n", + "1 ... 3 7 8 \n", + "2 ... 3 2 7 \n", + "3 ... 
## Add an explicit constant column (all ones) to act as the intercept term
LUNG_CANCER.loc[:, 'intercept'] = 1.0

## Preview the first four rows of the frame, now including `intercept`
LUNG_CANCER.head(n=4)
'''Define y and X'''
# Target: the encoded risk level. Cast to int so the estimator receives a
# numeric label vector; this fails loudly if `level` still holds the raw
# strings, i.e. if the cleaning cell was not run first.
y = LUNG_CANCER['level'].astype(int)

# Predictors: every column except the target, the patient identifier and the
# features that were dropped by hand.
columns_ = LUNG_CANCER.columns.tolist()
exclude_col = ['level', 'patient_id', 'alcohol_use', 'dust_allergy',
               'occupational_hazards', 'balanced_diet', 'obesity',
               'snoring', 'frequent_cold']
X = LUNG_CANCER[[col for col in columns_ if col not in exclude_col]]

# NOTE(review): LUNG_CANCER already carries a constant `intercept` column, so
# add_constant presumably skips adding a second one (the recorded shape is
# (1000, 17)) — confirm, then drop one of the two mechanisms.
X = st.add_constant(X, prepend=False)
print(X.shape, y.shape)

'''Split the data'''
# `sklearn.cross_validation` is deprecated since 0.18 and removed in 0.20;
# `train_test_split` lives in `sklearn.model_selection`.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.30, random_state=40
)

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
-1.24079736e+00,\n", + " -4.89756753e-01, 2.93152389e-06],\n", + " [ 3.66812895e-02, -2.78637530e-01, -5.08674811e-01,\n", + " 6.64688514e-01, -4.37367788e-01, -2.12402217e-01,\n", + " 1.47896467e-01, -7.19789255e-01, 2.09123966e-02,\n", + " 3.45001892e-01, -5.25706903e-01, -2.19421558e-01,\n", + " 7.03518991e-01, 1.15699818e-02, 6.00934480e-01,\n", + " 2.57988306e-02, 1.51172633e-08],\n", + " [-4.91276964e-03, 1.31173112e-01, 3.17359256e-01,\n", + " 4.06477580e-01, 4.04451304e-01, 1.43739344e-01,\n", + " 1.22726570e+00, 1.35997095e-02, 1.26434485e+00,\n", + " 1.19238469e+00, 4.22549226e-01, 5.73747853e-01,\n", + " 2.36119744e-01, 7.06587455e-01, 6.39862879e-01,\n", + " 4.63957923e-01, -2.94664117e-06]])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "## Set up the regression\n", + "\n", + "mul_lr = LogisticRegression(multi_class='multinomial',solver ='newton-cg').fit(X_train,y_train)\n", + "\n", + "## lets get the results\n", + "mul_lr.intercept_\n", + "mul_lr.coef_" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning: Maximum number of iterations has been exceeded.\n", + " Current function value: 0.097334\n", + " Iterations: 35\n", + " Function evaluations: 38\n", + " Gradient evaluations: 38\n", + " MNLogit Regression Results \n", + "==============================================================================\n", + "Dep. Variable: level No. 
Observations: 700\n", + "Model: MNLogit Df Residuals: 666\n", + "Method: MLE Df Model: 32\n", + "Date: Tue, 04 Dec 2018 Pseudo R-squ.: 0.9112\n", + "Time: 20:33:12 Log-Likelihood: -68.134\n", + "converged: False LL-Null: -767.63\n", + " LLR p-value: 5.963e-274\n", + "============================================================================================\n", + " level=2 coef std err z P>|z| [0.025 0.975]\n", + "--------------------------------------------------------------------------------------------\n", + "age -0.1031 0.060 -1.728 0.084 -0.220 0.014\n", + "gender -10.9784 3.356 -3.271 0.001 -17.556 -4.400\n", + "air_pollution -2.6913 1.131 -2.379 0.017 -4.908 -0.474\n", + "genetic_risk 5.0468 1.404 3.594 0.000 2.295 7.799\n", + "chronic_lung_disease -2.7845 0.910 -3.059 0.002 -4.569 -1.000\n", + "smoking -0.9366 0.569 -1.647 0.099 -2.051 0.178\n", + "passive_smoker 1.4782 0.672 2.200 0.028 0.162 2.795\n", + "chest_pain -0.8705 0.682 -1.276 0.202 -2.208 0.467\n", + "coughing_of_blood 0.0447 0.545 0.082 0.935 -1.024 1.113\n", + "fatigue 5.8949 1.561 3.775 0.000 2.835 8.955\n", + "weight_loss -1.6336 0.832 -1.964 0.050 -3.264 -0.003\n", + "shortness_of_breath -0.0884 0.551 -0.160 0.873 -1.169 0.992\n", + "wheezing 0.6777 0.877 0.773 0.440 -1.041 2.397\n", + "swallowing_difficulty 0.5567 0.668 0.834 0.405 -0.752 1.866\n", + "clubbing_of_finger_nails 3.7198 1.152 3.229 0.001 1.462 5.978\n", + "dry_cough -0.6692 0.579 -1.156 0.248 -1.804 0.465\n", + "intercept -5.1110 2.439 -2.096 0.036 -9.891 -0.331\n", + "--------------------------------------------------------------------------------------------\n", + " level=3 coef std err z P>|z| [0.025 0.975]\n", + "--------------------------------------------------------------------------------------------\n", + "age -1.0151 0.427 -2.379 0.017 -1.851 -0.179\n", + "gender -12.5765 3.659 -3.437 0.001 -19.749 -5.404\n", + "air_pollution 1.4011 1.765 0.794 0.427 -2.059 4.861\n", + "genetic_risk 4.5602 4.878 0.935 0.350 -5.000 
14.121\n", + "chronic_lung_disease 0.0066 3.619 0.002 0.999 -7.087 7.100\n", + "smoking -3.5495 1.631 -2.177 0.030 -6.746 -0.353\n", + "passive_smoker 11.2581 3.068 3.669 0.000 5.245 17.271\n", + "chest_pain -4.3841 3.281 -1.336 0.182 -10.815 2.047\n", + "coughing_of_blood 1.3247 3.298 0.402 0.688 -5.140 7.789\n", + "fatigue 7.5920 1.937 3.919 0.000 3.795 11.389\n", + "weight_loss -1.4729 2.535 -0.581 0.561 -6.441 3.495\n", + "shortness_of_breath 3.0151 1.000 3.016 0.003 1.056 4.975\n", + "wheezing -3.7620 1.886 -1.994 0.046 -7.459 -0.065\n", + "swallowing_difficulty 0.1687 0.817 0.206 0.836 -1.433 1.771\n", + "clubbing_of_finger_nails 0.6814 2.566 0.266 0.791 -4.348 5.710\n", + "dry_cough -1.6649 1.083 -1.537 0.124 -3.788 0.458\n", + "intercept -10.5903 3.496 -3.029 0.002 -17.442 -3.738\n", + "============================================================================================\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\Lina\\Anaconda3\\lib\\site-packages\\statsmodels\\base\\model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. 
Check mle_retvals\n", + " \"Check mle_retvals\", ConvergenceWarning)\n" + ] + } + ], + "source": [ + "## Set up the regression\n", + "\n", + "logit = sm.MNLogit(y_train, X_train)\n", + "logit_result = logit.fit(method='bfgs')\n", + "\n", + "## lets get the results\n", + "print(logit_result.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Coeffieients\n", + " 0 1\n", + "age -0.103143 -1.015068\n", + "gender -10.978373 -12.576457\n", + "air_pollution -2.691276 1.401134\n", + "genetic_risk 5.046776 4.560153\n", + "chronic_lung_disease -2.784542 0.006603\n", + "smoking -0.936573 -3.549460\n", + "passive_smoker 1.478245 11.258127\n", + "chest_pain -0.870475 -4.384082\n", + "coughing_of_blood 0.044676 1.324661\n", + "fatigue 5.894866 7.592025\n", + "weight_loss -1.633568 -1.472910\n", + "shortness_of_breath -0.088386 3.015135\n", + "wheezing 0.677735 -3.762030\n", + "swallowing_difficulty 0.556677 0.168725\n", + "clubbing_of_finger_nails 3.719845 0.681413\n", + "dry_cough -0.669196 -1.664935\n", + "intercept -5.111022 -10.590282\n", + "\n", + "\n", + "p-Values\n", + " 0 1\n", + "age 0.083975 0.017339\n", + "gender 0.001071 0.000589\n", + "air_pollution 0.017339 0.427364\n", + "genetic_risk 0.000325 0.349867\n", + "chronic_lung_disease 0.002224 0.998544\n", + "smoking 0.099470 0.029504\n", + "passive_smoker 0.027775 0.000243\n", + "chest_pain 0.201996 0.181521\n", + "coughing_of_blood 0.934670 0.687977\n", + "fatigue 0.000160 0.000089\n", + "weight_loss 0.049567 0.561152\n", + "shortness_of_breath 0.872586 0.002563\n", + "wheezing 0.439681 0.046131\n", + "swallowing_difficulty 0.404531 0.836471\n", + "clubbing_of_finger_nails 0.001242 0.790572\n", + "dry_cough 0.247558 0.124328\n", + "intercept 0.036120 0.002451\n", + "\n", + "\n", + "Dependent variables\n", + "level\n" + ] + } + ], + "source": [ + "print(\"Coeffieients\")\n", + 
"print(logit_result.params)\n", + "print (\"\\n\")\n", + "print(\"p-Values\")\n", + "print(logit_result.pvalues)\n", + "print (\"\\n\")\n", + "print(\"Dependent variables\")\n", + "print(logit.endog_names)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Interpreting logistic regression coefficients.\n", + "In this case, using the odds ratio will help us understand how a 1-unit increase or decrease in any of the variables affects the odds of a patient being at the Medium level (column 0) or the High level (column 1) rather than at the Low level, which is the baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 0 1\n", + "age 0.901998 0.362378\n", + "gender 0.000017 0.000003\n", + "air_pollution 0.067794 4.059801\n", + "genetic_risk 155.520200 95.598101\n", + "chronic_lung_disease 0.061757 1.006625\n", + "smoking 0.391969 0.028740\n", + "passive_smoker 4.385243 77507.296804\n", + "chest_pain 0.418752 0.012474\n", + "coughing_of_blood 1.045689 3.760909\n", + "fatigue 363.168194 1982.323098\n", + "weight_loss 0.195232 0.229257\n", + "shortness_of_breath 0.915408 20.391840\n", + "wheezing 1.969412 0.023237\n", + "swallowing_difficulty 1.744865 1.183795\n", + "clubbing_of_finger_nails 41.257983 1.976668\n", + "dry_cough 0.512120 0.189203\n", + "intercept 0.006030 0.000025\n" + ] + } + ], + "source": [ + "print (np.exp(logit_result.params))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These values are from our train set, now let's predict on our test set" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<span style=\"color:red\">Please explain more about the coefficients and p-values and what they mean, e.g. which features are most important? 
which has a higher influence on each level based on coefficients?</span>" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Predicting and Evaluating\n", + "If we call the predict method, we get the predicted probability of each class. To predict whether a patient's lung cancer level is Low, Medium or High, we must convert these probabilities into a class label by picking the most probable column (0 = Low, 1 = Medium, 2 = High). " + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 0 1 2\n", + "204 9.739892e-01 2.601079e-02 5.988912e-25\n", + "71 9.941086e-01 3.641338e-12 5.891434e-03\n", + "594 4.357147e-09 1.000000e+00 1.202880e-27\n", + "672 9.685383e-01 8.177787e-05 3.137988e-02\n", + "14 3.378482e-20 6.732186e-23 1.000000e+00\n", + "64 9.966876e-01 3.312377e-03 1.929552e-26\n", + "340 1.792525e-06 9.999982e-01 2.440023e-22\n", + "135 2.908870e-06 3.631127e-06 9.999935e-01\n", + "350 1.321207e-03 9.986788e-01 1.271600e-26\n", + "976 6.893179e-18 2.766205e-23 1.000000e+00\n" + ] + } + ], + "source": [ + "## Here we have the predictive probabilities\n", + "predictions = logit_result.predict(X_test)\n", + "print(predictions[:10])" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZ8AAAEJCAYAAABL3SrKAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAFmZJREFUeJzt3H9MVff9x/EXA01punKVwQVX1I2iBZ2rulwlVq2wsui+1dq0U2z3dVYrBuxsNn9gY1tJ2yG71tmBUiulmxsu04mFFu2yVpg/gLqsGKyGSpy10gJ36oiB4NDK94+G+y3lAvdeuB+49PlISDyf+/nc877v3NyX59xzbkBTU1O7AAAw6BsDXQAA4OuH8AEAGEf4AACMI3wAAMYRPgAA4wgfAIBxhA8AwDjCBwBgnF+HT21t7UCXMOjQE9foi2v0pSt64lp/98WvwwcA4J8IHwCAcYQPAMA4wgcAYFyv4bNt2zbNmTNHUVFRio6O1qJFi3T27Nlen/jMmTOaN2+eIiIiFBsbq6ysLLW38wPaAAA3wuf48eNavny5/vrXv6q4uFhBQUF66KGH9J///KfbNdeuXdPChQsVHh6uI0eOaMuWLcrOzlZOTk6/Fg8A8E9BvU0oLCzstL1r1y6NHj1alZWVmjt3rss1+/fvV2trq3JzcxUcHKy4uDidO3dOO3fu1OrVqxUQENA/1QMA/JLH3/k0Nzfr1q1bslgs3c45efKk4uPjFRwc7BxLTExUfX29Ll686F2lAIAho9cjn69KT0/X9773Pdlstm7nOBwOjRo1qtNYWFiY87GxY8e6XOfNTUwdaw79sEiXM04rfXSjrL+/rqD7C2S/eUifDDvuHHup7H90+6YQ59hnUdk69MMi59hvSi8q6P4CZWz+k27fFKIl0Qecz5Wx+U+dnr9q026PazWFm+Rcoy+u0Zeu6IlrnvYlJiam28c8Cp9nnnlGlZWVeueddxQYGNjj3K+eWuu42KCnU249FepKbW2tx2v6sr/+WutLfe3JUEVfXKMvXdET1/q7L26Hz8aNG1VYWKi33nqr2yOXDuHh4XI4HJ3GLl++LOn/j4AAAF9fbn3ns2HDBv3lL39RcXGxxo0b1+t8m82miooKXb9+3TlWWlqqyMhIjRkzxvtqAQBDQq/hs3btWu3du1d5eXmyWCxqbGxUY2OjmpubnXMyMjI0f/585/Yjjzyi4OBgpaam6uzZsyouLtb27duVmprKlW4AgN5Pu+Xl5UmSFixY0Gl8w4YN2rhxoySpoaFBFy5ccD4WEhKigwcPau3atZozZ44sFovS0tK0evXq/qwdAOCneg2fpqamXp8kNze3y9iECRN0+PBh76oCAAxp/LYbAMA4wgcAYBzhAwAwjvABABhH+AAAjCN8AADGET4AAOMIHwCAcYQPAMA4wgcAYBzhAwAwjvABABhH+AAAjCN8AADGET4AAOMIHwCAcYQPAMA4wgcAYBzhAwAwjvABABhH+AAAjCN8AADGET4AAOMIHwCAcYQPAMA4wgcAYBzhAwAwjvABABhH+AAAjCN8AADGET4AAOMIHwCAcYQPAMA4wgcAYBzhAwAwjvABABhH+AAAjCN8AADGET4AAOMIHwCAcYQPAMA4wgcAYBzhAwAwjvABABjnVvicOHFCixcvVmxsrCwWiwoKCnqcf/HiRVksli5/7777br8UDQDwb0HuTGppaVFcXJySk5O1atUqt5/8wIEDmjhxonN7xIgRnlcIABhy3AqfpKQkJSUlSZJSU1PdfvKRI0fKarV6VxkAYMjy6Xc+P/3pT3X33XfrRz/6kYqKiny5KwCAH3HryMdTd9xxh1544QVNnz5dQUFBOnTokJYtW6bc3FwtWrTIF7sEAPgRn4RPaGionnrqKef25MmTdfXqVb3yyis9hk9tba3H+/JmzUCv9bXBXNtAoi+u0Zeu6IlrnvYlJiam28d8Ej6uTJ06tder5Hoq1JX
a2lqP1/Rlf/211pf62pOhir64Rl+6oieu9XdfjN3nc/r0aS4+AABIcvPIp7m5Wf/6178kSbdu3VJdXZ2qq6s1YsQIRUVFKSMjQ//85z9VXFwsSdq7d6+GDRumSZMm6Rvf+Ibeeecd5eXlafPmzT57IQAA/+FW+FRVVenBBx90bmdmZiozM1PJycnKzc1VQ0ODLly40GnN1q1bdenSJQUGBio6Olo5OTlcbAAAkORm+MycOVNNTU3dPp6bm9tpe8mSJVqyZEnfKgMADFn8thsAwDjCBwBgHOEDADCO8AEAGEf4AACMI3wAAMYRPgAA4wgfAIBxhA8AwDjCBwBgHOEDADCO8AEAGEf4AACMI3wAAMYRPgAA4wgfAIBxhA8AwDjCBwBgHOEDADCO8AEAGEf4AACMI3wAAMYRPgAA4wgfAIBxhA8AwDjCBwBgHOEDADCO8AEAGEf4AACMI3wAAMYRPgAA4wgfAIBxhA8AwDjCBwBgHOEDADCO8AEAGEf4AACMI3wAAMYRPgAA4wgfAIBxhA8AwDjCBwBgHOEDADCO8AEAGEf4AACMcyt8Tpw4ocWLFys2NlYWi0UFBQW9rjlz5ozmzZuniIgIxcbGKisrS+3t7X0uGADg/9wKn5aWFsXFxWnLli0KDg7udf61a9e0cOFChYeH68iRI9qyZYuys7OVk5PT54IBAP4vyJ1JSUlJSkpKkiSlpqb2On///v1qbW1Vbm6ugoODFRcXp3Pnzmnnzp1avXq1AgIC+lY1AMCv+eQ7n5MnTyo+Pr7TUVJiYqLq6+t18eJFX+wSAOBH3Dry8ZTD4dCoUaM6jYWFhTkfGzt2rMt1tbW1Hu/LmzUdLG98qgwv13qz30M/LNK8dxd4uUf39aUnQxl9ca23vkx+8UlVbdptqJrBYbC+Vya/+KSC7i+Q/eYhfTLsuNJHN8r6++uq2rRbZWVl+mTYcf3vjE2Svvi8uX1TiD4Zdly/Kb2ooPsLlLH5T7p9U4iWRB+Q9ffXexxz9VnlaV9iYmK6fcwn4SOpy6m1josNejrl1lOhrtTW1nq8pr94u19f1zuQPRnM6Itr7vbl69Q7f3yvxMTEqKyszPnv/nrOL+vvvvjktFt4eLgcDkenscuXL0v6/yMgAMDXl0/Cx2azqaKiQtevX3eOlZaWKjIyUmPGjPHFLgEAfsSt8GlublZ1dbWqq6t169Yt1dXVqbq6WpcuXZIkZWRkaP78+c75jzzyiIKDg5WamqqzZ8+quLhY27dvV2pqKle6AQDcC5+qqirNmjVLs2bNUmtrqzIzMzVr1iz96le/kiQ1NDTowoULzvkhISE6ePCg6uvrNWfOHK1bt05paWlavXq1b14FAMCvuHXBwcyZM9XU1NTt47m5uV3GJkyYoMOHD3tfGQBgyOK33QAAxhE+AADjCB8AgHGEDwDAOMIHAGAc4QMAMI7wAQAYR/gAAIwjfAAAxhE+AADjCB8AgHGEDwDAOMIHAGAc4QMAMI7wAQAYR/gAAIwjfAAAxhE+AADjCB8AgHGEDwDAOMIHAGAc4QMAMI7wAQAYR/gAAIwjfAAAxhE+AADjCB8AgHGEDwDAOMIHAGAc4QMAMI7wAQAYR/gAAIwjfAAAxhE+AADjCB8AgHGEDwDAOMIHAGAc4QMAMI7wAQAYR/gAAIwjfAAAxhE+AADjCB8AgHGEDwDAOLfDJy8vT5MmTZLVatXs2bNVXl7e7dxjx47JYrF0+Tt37ly/FA0A8G9B7kwqLCxUenq6Xn75ZU2fPl15eXl69NFHVVlZqaioqG7XVVZWasSIEc7tb33rW32vGADg99w68tmxY4eWLFmipUuXavz48bLb7bJarcrPz+9xXVhYmKxWq/MvMDCwX4oGAPi3XsOnra1Np06dUkJCQqfxhIQEvf/++z2uvf/++zV+/HjNnz9fR48e7VulAIAho9fTbleuXNHnn3+usLCwTuNhYWFyOBwu10RERGjbtm2aMmWK2tra9Oc//1k
LFizQ22+/rRkzZvRP5QAAv+XWdz6SFBAQ0Gm7vb29y1iHmJgYxcTEOLdtNps++eQTZWdn9xg+tbW17pbTpzX9wdv9mqh3oHoy2NEX13rry2Q35gw1g/X1Tu5m/Mv19lftrp7H0+f+cg58Va/hExoaqsDAwC5HOZcvX+5yNNSTqVOnqrCwsMc5PRXqSm1trcdr+ou3+/V1vQPZk8GMvrjmbl++Tr3zx/dKTEyMysrKnP/ur+f8sv7uS6/f+QwfPlz33nuvSktLO42XlpZq2rRpbu/o9OnTslqtnlcIABhy3DrtlpaWppSUFE2dOlXTpk1Tfn6+GhoatGzZMklSSkqKJGnXrl2SpJ07d2r06NGKjY1VW1ub9u3bp5KSEu3Zs8dHLwMA4E/cCp+HH35YV69eld1uV2Njo2JjY7Vv3z6NHj1aklRXV9dp/o0bN/Tss8+qvr5et912m3N+UlJS/78CAIDfcfuCgxUrVmjFihUuHyspKem0vWbNGq1Zs6ZvlQEAhix+2w0AYBzhAwAwjvABABhH+AAAjCN8AADGET4AAOMIHwCAcYQPAMA4wgcAYBzhAwAwjvABABhH+AAAjCN8AADGET4AAOMIHwCAcYQPAMA4wgcAYBzhAwAwjvABABhH+AAAjCN8AADGET4AAOMIHwCAcYQPAMA4wgcAYBzhAwAwjvABABhH+AAAjCN8AADGET4AAOMIHwCAcYQPAMA4wgcAYBzhAwAwjvABABhH+AAAjCN8AADGET4AAOMIHwCAcYQPAMA4wgcAYBzhAwAwjvABABhH+AAAjCN8AADGuR0+eXl5mjRpkqxWq2bPnq3y8vIe5x8/flyzZ8+W1WrV97//feXn5/e5WADA0OBW+BQWFio9PV2//OUvdfToUdlsNj366KO6dOmSy/kff/yxfvKTn8hms+no0aP6xS9+ofXr16uoqKhfiwcA+Ce3wmfHjh1asmSJli5dqvHjx8tut8tqtXZ7NPPGG28oIiJCdrtd48eP19KlS5WcnKycnJx+LR4A4J8Cmpqa2nua0NbWpsjISL3++ut66KGHnONr167V2bNndejQoS5r5s6dqwkTJmjr1q3OsTfffFMrVqxQfX29hg0b1o8vAQDgb3o98rly5Yo+//xzhYWFdRoPCwuTw+FwucbhcLicf/PmTV25cqUP5QIAhgK3LzgICAjotN3e3t5lrLf5rsYBAF8/vYZPaGioAgMDuxzlXL58ucvRTYfw8HCX84OCgjRy5Mg+lAsAGAp6DZ/hw4fr3nvvVWlpaafx0tJSTZs2zeUam82msrKyLvMnT57M9z0AAPdOu6WlpWnv3r3as2ePPvroI23YsEENDQ1atmyZJCklJUUpKSnO+cuWLdNnn32m9PR0ffTRR9qzZ4/27t2r1atX++ZVAAD8ilvh8/DDDyszM1N2u10zZ85UZWWl9u3bp9GjR0uS6urqVFdX55w/duxY7du3T+Xl5Zo5c6a2bt2qrKwsLViwwO3CuKnVNU/6UlxcrIULFyo6Olp33XWXEhMTXV6dOBR4+n7pUFFRodDQUMXHx/u4woHhaV/a2tr00ksvadKkSQoPD9fEiRP16quvGqrWDE97sn//ft13332KjIzUuHHjtHLlSjU2Nhqq1owTJ05o8eLFio2NlcViUUFBQa9rzpw5o3nz5ikiIkKxsbHKyspyfrfvDrcvOFixYoVOnz4th8Ohv//975oxY4bzsZKSEpWUlHSaf9999+no0aNyOByqrq7WE0884XZR3NTqmqd9OXHihGbNmqV9+/bp6NGjeuCBB/T444+7/cHsLzztS4empiatWrVKs2fPNlSpWd70Zfny5Xrvvff0yiuv6B//+Id+97vfacKECQar9i1Pe1JZWamUlBQlJyeroqJCBQUFqqmp0ZNPPmm4ct9qaWlRXFyctmzZouDg4F7nX7t2TQsXLlR4eLiOHDmiLVu2KDs726N7OXu9z2cgJCYmasKECfrtb3/rHJsyZYo
WLFig559/vsv8559/Xm+99ZY++OAD59hTTz2lmpoa/e1vfzNSswme9sWVhIQExcfH66WXXvJVmcZ525fHH39cEydOVHt7u4qLi1VRUWGiXGM87cuRI0f0s5/9TFVVVQoNDTVZqjGe9iQ7O1u7du3Shx9+6Bz74x//qA0bNujTTz81UrNp3/72t/XrX/9ajz32WLdzXn/9dW3evFnnzp1zhpXdbld+fr7Onj3r1lXNg+6HRdva2nTq1CklJCR0Gk9ISND777/vcs3Jkye7zE9MTFRVVZVu3Ljhs1pN8qYvrjQ3N8tisfR3eQPG277k5eXJ4XBo3bp1vi5xQHjTl5KSEk2ePFk7duxQXFycpkyZovXr16u5udlEyT7nTU+mTZumxsZGHT58WO3t7bpy5YoKCwv1wAMPmCh50Dp58qTi4+M7HSUlJiaqvr5eFy9edOs5Bl34cFOra9705at2796tzz77TIsWLfJFiQPCm76cOXNGWVlZeu211xQYGGiiTOO86cvHH3+syspKffjhh9qzZ4/sdrvee+89paammijZ57zpic1mU15enlauXKmwsDBFR0ervb1dubm5JkoetLr7zO14zB2DLnw6cFOra572pUNRUZGee+45vfbaa84LRYYSd/vy3//+V8uXL9cLL7ygsWPHGqpu4Hjyfrl165YCAgK0e/du/eAHP1BiYqLsdruKi4vd/kDxB570pKamRunp6Vq3bp3Kysp04MABNTY26umnnzZR6qDW18/coH6vqI+4qdU1b/rSoaioSKtWrdKrr76qefPm+bJM4zztS0NDg2pqapSWlqa0tDRJX3zotre3KzQ0VPv37+9yWsYfefN+sVqtioyMVEhIiHNs3Lhxkr64ojU8PNx3BRvgTU+2bdumKVOm6Oc//7kkaeLEibr99ts1d+5cPfvss7rrrrt8Xvdg1N1nrqReP486DLojH25qdc2bvkjSwYMHlZKSop07d3p0qbu/8LQvo0aNUnl5uY4dO+b8e+KJJ/Td735Xx44dk81mM1W6T3nzfpk+fboaGho6fcdz/vx5SVJUVJTvijXEm560trZ2OTXbse3JZcVDjc1mU0VFha5fv+4cKy0tVWRkpMaMGePWcwSmp6dv9lF9XvvmN7+pzMxMRURE6LbbbpPdbld5eblycnIUEhKilJQUvf3223rwwQclSd/5zne0fft2/fvf/1ZUVJQOHTqkl19+WS+++KLuueeeAX41/cfTvhw4cEArV65URkaGkpKS1NLSopaWFt24ccOtyyn9hSd9CQwMVFhYWKe/Dz74QOfPn9fGjRs1fPjwgX45/cbT98vdd9+tgoICnTp1Svfcc4/Onz+vdevWacaMGT1e+eRPPO1Ja2ursrOzFRoaqpEjRzpPw1mtVq1Zs2aAX03/aW5uVk1NjRobG/WHP/xBcXFxuvPOO9XW1qaQkBBlZGRo27ZtSk5OliRFR0frjTfe0OnTpxUTE6OKigo999xzevrpp3v8z/CXDbrTbtIXN7VevXpVdrtdjY2Nio2N7XJT65d13NT6zDPPKD8/XxERER7f1OoPPO1Lfn6+bt68qY0bN2rjxo3O8RkzZnS5L8ufedqXrwtP+3LHHXfozTff1Pr165WQkCCLxaIf//jHbl/G7w887cljjz2m5uZm7d69W5s2bdKdd96pmTNnKiMjYyDK95mqqipn4EpSZmamMjMzlZycrNzcXDU0NOjChQvOx0NCQnTw4EGtXbtWc+bMkcViUVpamke/YjMo7/MBAAxtg+47HwDA0Ef4AACMI3wAAMYRPgAA4wgfAIBxhA8AwDjCBwBgHOEDADCO8AEAGPd/L8rwPGVvLckAAAAASUVORK5CYII=\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + 
"source": [ + "plt.hist(predictions);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Confusion matrix and Classification report\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", + " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n", + " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", + " verbose=0, warm_start=False)" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)\n", + "logreg = LogisticRegression()\n", + "logreg.fit(X_train, y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy of logistic regression classifier on test set: 0.97\n" + ] + } + ], + "source": [ + "y_pred = logreg.predict(X_test)\n", + "print('Accuracy of logistic regression classifier on test set: {:.2f}'.format(logreg.score(X_test, y_test)))" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "length of oversampled data is 720\n", + "Number of no subscription in oversampled data 0\n", + "Number of subscription 240\n", + "Proportion of no subscription data in oversampled data is 0.0\n", + "Proportion of subscription data in oversampled data is 0.3333333333333333\n" + ] + } + ], + "source": [ + "from imblearn.over_sampling import SMOTE\n", + "columns = X_train.columns\n", + "os = SMOTE(random_state=0)\n", + "os_data_X,os_data_y=os.fit_sample(X_train, y_train)\n", + "os_data_X = pd.DataFrame(data=os_data_X,columns=columns )\n", + "os_data_y= 
pd.DataFrame(data=os_data_y,columns=['y'])\n", + "# Check the size and class balance of the oversampled data.\n", + "# Labels are the cancer levels 1=Low, 2=Medium, 3=High (there is no 0 class).\n", + "print(\"length of oversampled data is \",len(os_data_X))\n", + "level_counts = os_data_y['y'].value_counts().sort_index()\n", + "print(\"Number of patients per level in oversampled data:\")\n", + "print(level_counts)\n", + "print(\"Proportion of patients per level in oversampled data:\")\n", + "print(level_counts / len(os_data_X))" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ True True True True True True True True True True True True\n", + " True True True True True]\n", + "[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n" + ] + } + ], + "source": [ + "from sklearn.feature_selection import RFE\n", + "logreg = LogisticRegression()\n", + "# n_features_to_select is passed by keyword (newer scikit-learn rejects it positionally).\n", + "# NOTE: 20 exceeds the 17 available columns, so every feature is kept (all ranked 1).\n", + "rfe = RFE(logreg, n_features_to_select=20)\n", + "rfe = rfe.fit(os_data_X, os_data_y.values.ravel())\n", + "print(rfe.support_)\n", + "print(rfe.ranking_)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 76 0 0]\n", + " [ 4 83 5]\n", + " [ 0 0 132]]\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>Predict_Label_0 Low</th>\n", + " <th>Predict_Label_1 Medium</th>\n", + " <th>Predict_Label_2 High</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>True_Label_0 
Low</th>\n", + " <td>76</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>True_Label_1 Medium</th>\n", + " <td>4</td>\n", + " <td>83</td>\n", + " <td>5</td>\n", + " </tr>\n", + " <tr>\n", + " <th>True_Label_2 High</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>132</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " Predict_Label_0 Low Predict_Label_1 Medium \\\n", + "True_Label_0 Low 76 0 \n", + "True_Label_1 Medium 4 83 \n", + "True_Label_2 High 0 0 \n", + "\n", + " Predict_Label_2 High \n", + "True_Label_0 Low 0 \n", + "True_Label_1 Medium 5 \n", + "True_Label_2 High 132 " + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.metrics import confusion_matrix\n", + "confusion_matrix = confusion_matrix(y_test, y_pred)\n", + "print(confusion_matrix)\n", + "confusion = pd.DataFrame(confusion_matrix,index=['True_Label_0 Low', 'True_Label_1 Medium','True_Label_2 High'],\n", + " columns=['Predict_Label_0 Low', 'Predict_Label_1 Medium','Predict_Label_2 High'])\n", + "\n", + "confusion" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 40 0 0 0 0 0 0]\n", + " [121 91 0 0 0 0 0]\n", + " [ 92 81 0 0 0 0 0]\n", + " [ 20 20 0 0 0 0 0]\n", + " [ 0 20 80 0 0 0 0]\n", + " [ 20 20 68 0 0 0 0]\n", + " [ 10 100 217 0 0 0 0]]\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>level Low</th>\n", + " <th>level 
Medium</th>\n", + " <th>level High</th>\n", + " <th>Predected_Label_4</th>\n", + " <th>Predected_Label_5</th>\n", + " <th>Predected_Label_6</th>\n", + " <th>Predected_Label_7</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>genetic_risk_level_1</th>\n", + " <td>40</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>genetic_risk_level_2</th>\n", + " <td>121</td>\n", + " <td>91</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>genetic_risk_level_3</th>\n", + " <td>92</td>\n", + " <td>81</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>genetic_risk_level_4</th>\n", + " <td>20</td>\n", + " <td>20</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>genetic_risk_level_5</th>\n", + " <td>0</td>\n", + " <td>20</td>\n", + " <td>80</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>genetic_risk_level_6</th>\n", + " <td>20</td>\n", + " <td>20</td>\n", + " <td>68</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>genetic_risk_level_7</th>\n", + " <td>10</td>\n", + " <td>100</td>\n", + " <td>217</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " level Low level Medium level High Predected_Label_4 \\\n", + "genetic_risk_level_1 40 0 0 0 \n", + "genetic_risk_level_2 121 91 0 0 \n", + "genetic_risk_level_3 92 81 0 0 \n", + "genetic_risk_level_4 20 20 0 0 \n", + "genetic_risk_level_5 0 20 80 0 \n", + "genetic_risk_level_6 20 20 68 0 \n", 
+ "genetic_risk_level_7 10 100 217 0 \n", + "\n", + " Predected_Label_5 Predected_Label_6 Predected_Label_7 \n", + "genetic_risk_level_1 0 0 0 \n", + "genetic_risk_level_2 0 0 0 \n", + "genetic_risk_level_3 0 0 0 \n", + "genetic_risk_level_4 0 0 0 \n", + "genetic_risk_level_5 0 0 0 \n", + "genetic_risk_level_6 0 0 0 \n", + "genetic_risk_level_7 0 0 0 " + ] + }, + "execution_count": 66, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.metrics import confusion_matrix\n", + "confusion_matrix = confusion_matrix(LUNG_CANCER.genetic_risk, LUNG_CANCER.level)\n", + "print(confusion_matrix)\n", + "confusion = pd.DataFrame(confusion_matrix,index=['genetic_risk_level_1', 'genetic_risk_level_2','genetic_risk_level_3','genetic_risk_level_4', 'genetic_risk_level_5','genetic_risk_level_6','genetic_risk_level_7'],\n", + " columns=['level Low', 'level Medium','level High','Predected_Label_4', 'Predected_Label_5','Predected_Label_6','Predected_Label_7'])\n", + "\n", + "confusion" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", + " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n", + " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", + " verbose=0, warm_start=False)" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)\n", + "logreg = LogisticRegression()\n", + "logreg.fit(X_train, y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy of logistic regression classifier on test set: 0.97\n" + ] + } + ], + "source": [ + "y_pred = logreg.predict(X_test)\n", + "print('Accuracy of 
logistic regression classifier on test set: {:.2f}'.format(logreg.score(X_test, y_test)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute precision, recall, F-measure and support" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 1 0.950 1.000 0.974 76\n", + " 2 1.000 0.902 0.949 92\n", + " 3 0.964 1.000 0.981 132\n", + "\n", + "avg / total 0.971 0.970 0.970 300\n", + "\n" + ] + } + ], + "source": [ + "from sklearn.metrics import classification_report\n", + "print (classification_report(y_test, y_pred, digits=3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Interpretation: The classifier assigns the correct cancer level to 97% of the 300 patients in the test set. Recall is lowest for level 2 (Medium) at 0.902: 9 of the 92 Medium patients are misclassified, while every Low and High patient is identified correctly." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Lets implement the same logistic regression using scikit learn\n", + "\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3 365\n", + "2 332\n", + "1 303\n", + "Name: level, dtype: int64 \n", + "\n" + ] + } + ], + "source": [ + "'''Remeber that 1 is Low, 2 is Medium, 3 is High'''\n", + "print (LUNG_CANCER['level'].value_counts(), \"\\n\" )" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", + " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n", + " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", + " verbose=0, warm_start=False)" + ] + }, + "execution_count": 82, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "logistic = LogisticRegression()\n", + "logistic.fit(X_train, y_train)" + ] + 
}, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 76 0 0]\n", + " [ 4 83 5]\n", + " [ 0 0 132]]\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>Predict_Label_0 Low</th>\n", + " <th>Predict_Label_1 Medium</th>\n", + " <th>Predict_Label_2 High</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>True_Label_0 Low</th>\n", + " <td>76</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>True_Label_1 Medium</th>\n", + " <td>4</td>\n", + " <td>83</td>\n", + " <td>5</td>\n", + " </tr>\n", + " <tr>\n", + " <th>True_Label_2 High</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>132</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " Predict_Label_0 Low Predict_Label_1 Medium \\\n", + "True_Label_0 Low 76 0 \n", + "True_Label_1 Medium 4 83 \n", + "True_Label_2 High 0 0 \n", + "\n", + " Predict_Label_2 High \n", + "True_Label_0 Low 0 \n", + "True_Label_1 Medium 5 \n", + "True_Label_2 High 132 " + ] + }, + "execution_count": 83, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.metrics import confusion_matrix\n", + "y_pred=logistic.predict(X_test)\n", + "confusion_matrix = confusion_matrix(y_test, y_pred)\n", + "print(confusion_matrix)\n", + "confusion = pd.DataFrame(confusion_matrix,index=['True_Label_0 Low', 'True_Label_1 Medium','True_Label_2 High'],\n", + " columns=['Predict_Label_0 Low', 'Predict_Label_1 
Medium','Predict_Label_2 High'])\n", + "\n", + "confusion" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 1 0.950 1.000 0.974 76\n", + " 2 1.000 0.902 0.949 92\n", + " 3 0.964 1.000 0.981 132\n", + "\n", + "avg / total 0.971 0.970 0.970 300\n", + "\n" + ] + } + ], + "source": [ + "print (classification_report(y_test, y_pred, digits=3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Interpretation: The classifier assigns the correct cancer level to 97% of the 300 patients in the test set. Recall is lowest for level 2 (Medium) at 0.902: 9 of the 92 Medium patients are misclassified, while every Low and High patient is identified correctly." + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", + " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n", + " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", + " verbose=0, warm_start=False)" + ] + }, + "execution_count": 86, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''Use scikit learn'''\n", + "r_d_logistic = LogisticRegression()\n", + "r_d_logistic.fit(X_train, y_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Calculate accuracy, Misclassification Rate (Error Rate), Precision, Recall\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy score: 97.000\n" + ] + } + ], + "source": [ + "## Accuracy\n", + "## How often is the classifier correct?\n", + "from sklearn.metrics import accuracy_score\n", + "\n", + "acc = accuracy_score(y_test, y_pred)\n", + "print (\"Accuracy score: %.3f\" %(acc*100))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + 
"metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}