[0f681c]: / scripts / Iftah_Classification Analysis_full_features.ipynb

Download this file

1466 lines (1465 with data), 157.2 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from seaborn import set_style\n",
    "from sklearn.metrics import recall_score, precision_score, roc_auc_score, accuracy_score, roc_curve\n",
    "\n",
    "# Plot style: plain white background WITHOUT grid lines.\n",
    "# (Switch to set_style(\"whitegrid\") if grid lines are wanted.)\n",
    "set_style(\"white\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set up datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the raw time-series table and its FFT counterpart; both share the layout\n",
    "# [saved row index, feature columns..., label].\n",
    "data, data_fft = (pd.read_csv(csv_path) for csv_path in (\"time_series_data_v1.csv\", \"time_series_data_v1_fft.csv\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 184,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>...</th>\n",
       "      <th>7491</th>\n",
       "      <th>7492</th>\n",
       "      <th>7493</th>\n",
       "      <th>7494</th>\n",
       "      <th>7495</th>\n",
       "      <th>7496</th>\n",
       "      <th>7497</th>\n",
       "      <th>7498</th>\n",
       "      <th>7499</th>\n",
       "      <th>7500</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>-1.919012</td>\n",
       "      <td>-1.919012</td>\n",
       "      <td>-1.919012</td>\n",
       "      <td>-1.919012</td>\n",
       "      <td>37.580988</td>\n",
       "      <td>37.179754</td>\n",
       "      <td>37.580988</td>\n",
       "      <td>77.580988</td>\n",
       "      <td>77.080988</td>\n",
       "      <td>...</td>\n",
       "      <td>10.779934</td>\n",
       "      <td>12.404934</td>\n",
       "      <td>14.404934</td>\n",
       "      <td>15.904934</td>\n",
       "      <td>16.404934</td>\n",
       "      <td>17.404934</td>\n",
       "      <td>18.904934</td>\n",
       "      <td>18.654934</td>\n",
       "      <td>17.904934</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>-2.056393</td>\n",
       "      <td>-2.056393</td>\n",
       "      <td>-2.056393</td>\n",
       "      <td>-2.056393</td>\n",
       "      <td>37.456107</td>\n",
       "      <td>37.443607</td>\n",
       "      <td>37.443607</td>\n",
       "      <td>76.943607</td>\n",
       "      <td>76.943607</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.243339</td>\n",
       "      <td>-0.743339</td>\n",
       "      <td>-0.243339</td>\n",
       "      <td>0.756661</td>\n",
       "      <td>1.256661</td>\n",
       "      <td>2.256661</td>\n",
       "      <td>3.756661</td>\n",
       "      <td>4.256661</td>\n",
       "      <td>4.756661</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>-1.527552</td>\n",
       "      <td>-1.527552</td>\n",
       "      <td>-1.527552</td>\n",
       "      <td>-1.527552</td>\n",
       "      <td>37.972448</td>\n",
       "      <td>37.571214</td>\n",
       "      <td>37.972448</td>\n",
       "      <td>77.972448</td>\n",
       "      <td>77.472448</td>\n",
       "      <td>...</td>\n",
       "      <td>-38.638512</td>\n",
       "      <td>-39.538512</td>\n",
       "      <td>-40.538512</td>\n",
       "      <td>-40.538512</td>\n",
       "      <td>-39.438512</td>\n",
       "      <td>-38.538512</td>\n",
       "      <td>-37.038512</td>\n",
       "      <td>-33.038512</td>\n",
       "      <td>-32.913512</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>-1.683246</td>\n",
       "      <td>-1.683246</td>\n",
       "      <td>-1.683246</td>\n",
       "      <td>-1.683246</td>\n",
       "      <td>37.816754</td>\n",
       "      <td>37.415520</td>\n",
       "      <td>37.816754</td>\n",
       "      <td>77.816754</td>\n",
       "      <td>77.316754</td>\n",
       "      <td>...</td>\n",
       "      <td>5.155963</td>\n",
       "      <td>5.405963</td>\n",
       "      <td>5.905963</td>\n",
       "      <td>6.572629</td>\n",
       "      <td>6.655963</td>\n",
       "      <td>7.405963</td>\n",
       "      <td>8.405963</td>\n",
       "      <td>8.905963</td>\n",
       "      <td>8.905963</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>-1.496040</td>\n",
       "      <td>-1.496040</td>\n",
       "      <td>-1.496040</td>\n",
       "      <td>-1.496040</td>\n",
       "      <td>38.003960</td>\n",
       "      <td>37.602726</td>\n",
       "      <td>38.003960</td>\n",
       "      <td>78.003960</td>\n",
       "      <td>77.503960</td>\n",
       "      <td>...</td>\n",
       "      <td>-2.816436</td>\n",
       "      <td>-2.816436</td>\n",
       "      <td>-2.816436</td>\n",
       "      <td>-2.316436</td>\n",
       "      <td>-1.816436</td>\n",
       "      <td>-1.983103</td>\n",
       "      <td>-1.816436</td>\n",
       "      <td>-1.816436</td>\n",
       "      <td>-2.316436</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 7502 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0         0         1         2         3          4          5  \\\n",
       "0           0 -1.919012 -1.919012 -1.919012 -1.919012  37.580988  37.179754   \n",
       "1           1 -2.056393 -2.056393 -2.056393 -2.056393  37.456107  37.443607   \n",
       "2           2 -1.527552 -1.527552 -1.527552 -1.527552  37.972448  37.571214   \n",
       "3           3 -1.683246 -1.683246 -1.683246 -1.683246  37.816754  37.415520   \n",
       "4           4 -1.496040 -1.496040 -1.496040 -1.496040  38.003960  37.602726   \n",
       "\n",
       "           6          7          8  ...       7491       7492       7493  \\\n",
       "0  37.580988  77.580988  77.080988  ...  10.779934  12.404934  14.404934   \n",
       "1  37.443607  76.943607  76.943607  ...  -1.243339  -0.743339  -0.243339   \n",
       "2  37.972448  77.972448  77.472448  ... -38.638512 -39.538512 -40.538512   \n",
       "3  37.816754  77.816754  77.316754  ...   5.155963   5.405963   5.905963   \n",
       "4  38.003960  78.003960  77.503960  ...  -2.816436  -2.816436  -2.816436   \n",
       "\n",
       "        7494       7495       7496       7497       7498       7499  7500  \n",
       "0  15.904934  16.404934  17.404934  18.904934  18.654934  17.904934   0.0  \n",
       "1   0.756661   1.256661   2.256661   3.756661   4.256661   4.756661   0.0  \n",
       "2 -40.538512 -39.438512 -38.538512 -37.038512 -33.038512 -32.913512   0.0  \n",
       "3   6.572629   6.655963   7.405963   8.405963   8.905963   8.905963   0.0  \n",
       "4  -2.316436  -1.816436  -1.983103  -1.816436  -1.816436  -2.316436   0.0  \n",
       "\n",
       "[5 rows x 7502 columns]"
      ]
     },
     "execution_count": 184,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first rows of the raw time-series table.\n",
    "data.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>...</th>\n",
       "      <th>7491</th>\n",
       "      <th>7492</th>\n",
       "      <th>7493</th>\n",
       "      <th>7494</th>\n",
       "      <th>7495</th>\n",
       "      <th>7496</th>\n",
       "      <th>7497</th>\n",
       "      <th>7498</th>\n",
       "      <th>7499</th>\n",
       "      <th>7500</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>8.810730e-13</td>\n",
       "      <td>1754.620136</td>\n",
       "      <td>1111.232507</td>\n",
       "      <td>877.424567</td>\n",
       "      <td>1612.482093</td>\n",
       "      <td>1496.972754</td>\n",
       "      <td>1290.561420</td>\n",
       "      <td>819.749901</td>\n",
       "      <td>1041.156190</td>\n",
       "      <td>...</td>\n",
       "      <td>1667.542984</td>\n",
       "      <td>911.635367</td>\n",
       "      <td>2267.761854</td>\n",
       "      <td>1798.143874</td>\n",
       "      <td>1976.902999</td>\n",
       "      <td>2528.165282</td>\n",
       "      <td>1721.001574</td>\n",
       "      <td>3034.219976</td>\n",
       "      <td>1726.239727</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>5.258016e-12</td>\n",
       "      <td>1717.776454</td>\n",
       "      <td>1575.638489</td>\n",
       "      <td>1551.712801</td>\n",
       "      <td>1049.344949</td>\n",
       "      <td>1242.986193</td>\n",
       "      <td>1173.008962</td>\n",
       "      <td>1595.635646</td>\n",
       "      <td>807.680360</td>\n",
       "      <td>...</td>\n",
       "      <td>857.542235</td>\n",
       "      <td>1317.682629</td>\n",
       "      <td>753.778000</td>\n",
       "      <td>2089.817295</td>\n",
       "      <td>926.026290</td>\n",
       "      <td>955.332643</td>\n",
       "      <td>2225.841039</td>\n",
       "      <td>881.140039</td>\n",
       "      <td>1866.132179</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>3.666401e-12</td>\n",
       "      <td>1222.182666</td>\n",
       "      <td>1160.233433</td>\n",
       "      <td>1075.590271</td>\n",
       "      <td>1604.629377</td>\n",
       "      <td>1295.074248</td>\n",
       "      <td>1460.742239</td>\n",
       "      <td>2711.975076</td>\n",
       "      <td>720.778185</td>\n",
       "      <td>...</td>\n",
       "      <td>475.877793</td>\n",
       "      <td>561.915265</td>\n",
       "      <td>1011.705599</td>\n",
       "      <td>921.503744</td>\n",
       "      <td>831.159064</td>\n",
       "      <td>900.202411</td>\n",
       "      <td>903.285798</td>\n",
       "      <td>863.718676</td>\n",
       "      <td>977.430996</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>3.765876e-12</td>\n",
       "      <td>1402.132765</td>\n",
       "      <td>1482.460656</td>\n",
       "      <td>1269.357953</td>\n",
       "      <td>1423.200436</td>\n",
       "      <td>1098.530975</td>\n",
       "      <td>1201.729242</td>\n",
       "      <td>1233.002046</td>\n",
       "      <td>1588.659021</td>\n",
       "      <td>...</td>\n",
       "      <td>1567.586264</td>\n",
       "      <td>1529.488217</td>\n",
       "      <td>1087.482628</td>\n",
       "      <td>1626.073792</td>\n",
       "      <td>1114.458183</td>\n",
       "      <td>1633.220659</td>\n",
       "      <td>1771.324525</td>\n",
       "      <td>1464.746585</td>\n",
       "      <td>1276.345344</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>2.444267e-12</td>\n",
       "      <td>3150.845355</td>\n",
       "      <td>1139.146049</td>\n",
       "      <td>1288.776028</td>\n",
       "      <td>1391.246330</td>\n",
       "      <td>1552.294539</td>\n",
       "      <td>998.176271</td>\n",
       "      <td>762.994498</td>\n",
       "      <td>1182.613670</td>\n",
       "      <td>...</td>\n",
       "      <td>1006.640784</td>\n",
       "      <td>2071.525278</td>\n",
       "      <td>1989.091452</td>\n",
       "      <td>1410.119838</td>\n",
       "      <td>2846.517129</td>\n",
       "      <td>3822.363668</td>\n",
       "      <td>3203.469711</td>\n",
       "      <td>1959.778941</td>\n",
       "      <td>1199.812164</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 7502 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0             0            1            2            3  \\\n",
       "0           0  8.810730e-13  1754.620136  1111.232507   877.424567   \n",
       "1           1  5.258016e-12  1717.776454  1575.638489  1551.712801   \n",
       "2           2  3.666401e-12  1222.182666  1160.233433  1075.590271   \n",
       "3           3  3.765876e-12  1402.132765  1482.460656  1269.357953   \n",
       "4           4  2.444267e-12  3150.845355  1139.146049  1288.776028   \n",
       "\n",
       "             4            5            6            7            8  ...  \\\n",
       "0  1612.482093  1496.972754  1290.561420   819.749901  1041.156190  ...   \n",
       "1  1049.344949  1242.986193  1173.008962  1595.635646   807.680360  ...   \n",
       "2  1604.629377  1295.074248  1460.742239  2711.975076   720.778185  ...   \n",
       "3  1423.200436  1098.530975  1201.729242  1233.002046  1588.659021  ...   \n",
       "4  1391.246330  1552.294539   998.176271   762.994498  1182.613670  ...   \n",
       "\n",
       "          7491         7492         7493         7494         7495  \\\n",
       "0  1667.542984   911.635367  2267.761854  1798.143874  1976.902999   \n",
       "1   857.542235  1317.682629   753.778000  2089.817295   926.026290   \n",
       "2   475.877793   561.915265  1011.705599   921.503744   831.159064   \n",
       "3  1567.586264  1529.488217  1087.482628  1626.073792  1114.458183   \n",
       "4  1006.640784  2071.525278  1989.091452  1410.119838  2846.517129   \n",
       "\n",
       "          7496         7497         7498         7499  7500  \n",
       "0  2528.165282  1721.001574  3034.219976  1726.239727   0.0  \n",
       "1   955.332643  2225.841039   881.140039  1866.132179   0.0  \n",
       "2   900.202411   903.285798   863.718676   977.430996   0.0  \n",
       "3  1633.220659  1771.324525  1464.746585  1276.345344   0.0  \n",
       "4  3822.363668  3203.469711  1959.778941  1199.812164   0.0  \n",
       "\n",
       "[5 rows x 7502 columns]"
      ]
     },
     "execution_count": 185,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first rows of the FFT feature table.\n",
    "data_fft.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Split settings shared by both feature sets so their train/test splits stay identical.\n",
    "TEST_SIZE = 0.3\n",
    "SPLIT_SEED = 111\n",
    "\n",
    "# Column 0 ('Unnamed: 0') is the saved row index and the last column is the label;\n",
    "# the columns in between are the features.\n",
    "X = data.iloc[:, 1:-1].to_numpy()\n",
    "X_fft = data_fft.iloc[:, 1:-1].to_numpy()\n",
    "\n",
    "# Binarize the label (0 -> 0, any non-zero class -> 1) into a NEW array.\n",
    "# Assigning into the result of .to_numpy() is unsafe: it may be a view of the\n",
    "# DataFrame (silently rewriting `data`), or read-only under pandas Copy-on-Write.\n",
    "y = (data.iloc[:, -1].to_numpy() != 0).astype(int)\n",
    "y_fft = (data_fft.iloc[:, -1].to_numpy() != 0).astype(int)\n",
    "\n",
    "# The raw-vs-FFT comparisons below assume both tables hold the same samples in the\n",
    "# same order; with equal labels, the equal stratified splits then select the same rows.\n",
    "assert np.array_equal(y, y_fft), \"data and data_fft labels are not aligned row-for-row\"\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=TEST_SIZE, random_state=SPLIT_SEED, stratify=y\n",
    ")\n",
    "X_fft_train, X_fft_test, y_fft_train, y_fft_test = train_test_split(\n",
    "    X_fft, y_fft, test_size=TEST_SIZE, random_state=SPLIT_SEED, stratify=y_fft\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# KNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7357142857142858\n",
      "0.9174757281553398\n",
      "0.5294117647058824\n"
     ]
    }
   ],
   "source": [
    "# 3-nearest-neighbour classifier on the raw time-series features.\n",
    "# Prints accuracy, precision and recall on the held-out split, in that order.\n",
    "knn = KNeighborsClassifier(n_neighbors=3)\n",
    "y_pred = knn.fit(X_train, y_train).predict(X_test)\n",
    "for score_fn in (accuracy_score, precision_score, recall_score):\n",
    "    print(score_fn(y_test, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.82\n",
      "0.8786885245901639\n",
      "0.7507002801120448\n"
     ]
    }
   ],
   "source": [
    "# Same 3-NN classifier, trained and scored on the FFT features instead.\n",
    "# Prints accuracy, precision and recall on the held-out split, in that order.\n",
    "knn = KNeighborsClassifier(n_neighbors=3)\n",
    "y_pred = knn.fit(X_fft_train, y_fft_train).predict(X_fft_test)\n",
    "for score_fn in (accuracy_score, precision_score, recall_score):\n",
    "    print(score_fn(y_fft_test, y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Random Forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 191,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8771428571428571\n",
      "0.8613333333333333\n",
      "0.9047619047619048\n"
     ]
    }
   ],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "# Random forest on the raw time-series features; the fixed seed makes it repeatable.\n",
    "forest_clf = RandomForestClassifier(\n",
    "    n_estimators=1000,\n",
    "    max_depth=100,\n",
    "    max_samples=800,\n",
    "    random_state=614,\n",
    ")\n",
    "y_pred = forest_clf.fit(X_train, y_train).predict(X_test)\n",
    "for score_fn in (accuracy_score, precision_score, recall_score):\n",
    "    print(score_fn(y_test, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 210,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8985714285714286\n",
      "0.8994413407821229\n",
      "0.9019607843137255\n"
     ]
    }
   ],
   "source": [
    "# Random forest on the FFT features, with default hyperparameters.\n",
    "# It is fit ONLY on the FFT training set (a fit on raw X_train would be discarded\n",
    "# by this fit anyway), and seeded so the printed scores are reproducible.\n",
    "forest_clf = RandomForestClassifier(random_state=614)\n",
    "forest_clf.fit(X_fft_train, y_fft_train)\n",
    "\n",
    "# Prints accuracy, precision and recall on the held-out split, in that order.\n",
    "y_pred = forest_clf.predict(X_fft_test)\n",
    "for score_fn in (accuracy_score, precision_score, recall_score):\n",
    "    print(score_fn(y_fft_test, y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 194,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras import models\n",
    "from keras import layers\n",
    "from keras import optimizers\n",
    "from keras import losses\n",
    "from keras import metrics\n",
    "from keras.utils import to_categorical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 195,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empty Keras Sequential container for the dense network.\n",
    "model = models.Sequential(layers=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 196,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_20\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_81 (Dense)             (None, 16)                120016    \n",
      "_________________________________________________________________\n",
      "dense_82 (Dense)             (None, 16)                272       \n",
      "_________________________________________________________________\n",
      "dense_83 (Dense)             (None, 2)                 34        \n",
      "=================================================================\n",
      "Total params: 120,322\n",
      "Trainable params: 120,322\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Build the whole network in this cell so re-running it starts from a fresh model\n",
    "# instead of appending duplicate layers to an existing `model`.\n",
    "n_features = X_train.shape[1]  # 7500 feature columns; X_fft_train has the same width\n",
    "model = models.Sequential()\n",
    "model.add(layers.Dense(16, activation='relu', input_shape=(n_features,)))\n",
    "model.add(layers.Dense(16, activation='relu'))\n",
    "model.add(layers.Dense(2, activation='softmax'))  # one output per (binary) class\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 197,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we compile the network like so\n",
    "model.compile(optimizer='rmsprop',\n",
    "                  loss='categorical_crossentropy',\n",
    "                  metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 198,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 3.4800 - accuracy: 0.5942 - val_loss: 0.9082 - val_accuracy: 0.7071\n",
      "Epoch 2/100\n",
      "17/17 [==============================] - 0s 17ms/step - loss: 0.4404 - accuracy: 0.8231 - val_loss: 1.1741 - val_accuracy: 0.7600\n",
      "Epoch 3/100\n",
      "17/17 [==============================] - 0s 20ms/step - loss: 0.4871 - accuracy: 0.8882 - val_loss: 0.9500 - val_accuracy: 0.7757\n",
      "Epoch 4/100\n",
      "17/17 [==============================] - 0s 15ms/step - loss: 0.1894 - accuracy: 0.9422 - val_loss: 0.9177 - val_accuracy: 0.7557\n",
      "Epoch 5/100\n",
      "17/17 [==============================] - 0s 25ms/step - loss: 0.1837 - accuracy: 0.9429 - val_loss: 1.0315 - val_accuracy: 0.7871\n",
      "Epoch 6/100\n",
      "17/17 [==============================] - 0s 26ms/step - loss: 0.2868 - accuracy: 0.9291 - val_loss: 1.0552 - val_accuracy: 0.8029\n",
      "Epoch 7/100\n",
      "17/17 [==============================] - 0s 25ms/step - loss: 0.1027 - accuracy: 0.9667 - val_loss: 1.0030 - val_accuracy: 0.7986\n",
      "Epoch 8/100\n",
      "17/17 [==============================] - 0s 19ms/step - loss: 0.0357 - accuracy: 0.9914 - val_loss: 1.0517 - val_accuracy: 0.8100\n",
      "Epoch 9/100\n",
      "17/17 [==============================] - 0s 24ms/step - loss: 0.0509 - accuracy: 0.9854 - val_loss: 1.3149 - val_accuracy: 0.7957\n",
      "Epoch 10/100\n",
      "17/17 [==============================] - 0s 18ms/step - loss: 0.1221 - accuracy: 0.9743 - val_loss: 1.4089 - val_accuracy: 0.7929\n",
      "Epoch 11/100\n",
      "17/17 [==============================] - 0s 20ms/step - loss: 0.1572 - accuracy: 0.9748 - val_loss: 1.4016 - val_accuracy: 0.8029\n",
      "Epoch 12/100\n",
      "17/17 [==============================] - 0s 17ms/step - loss: 0.0264 - accuracy: 0.9925 - val_loss: 1.3857 - val_accuracy: 0.8057\n",
      "Epoch 13/100\n",
      "17/17 [==============================] - 1s 32ms/step - loss: 0.0701 - accuracy: 0.9844 - val_loss: 1.5484 - val_accuracy: 0.8057\n",
      "Epoch 14/100\n",
      "17/17 [==============================] - 1s 40ms/step - loss: 0.0547 - accuracy: 0.9869 - val_loss: 1.4543 - val_accuracy: 0.7871\n",
      "Epoch 15/100\n",
      "17/17 [==============================] - 0s 29ms/step - loss: 0.0352 - accuracy: 0.9904 - val_loss: 1.4931 - val_accuracy: 0.8343\n",
      "Epoch 16/100\n",
      "17/17 [==============================] - 0s 25ms/step - loss: 0.0199 - accuracy: 0.9927 - val_loss: 1.6465 - val_accuracy: 0.8400\n",
      "Epoch 17/100\n",
      "17/17 [==============================] - 1s 30ms/step - loss: 0.0331 - accuracy: 0.9964 - val_loss: 1.7887 - val_accuracy: 0.8100\n",
      "Epoch 18/100\n",
      "17/17 [==============================] - 0s 22ms/step - loss: 0.0470 - accuracy: 0.9933 - val_loss: 1.6822 - val_accuracy: 0.8114\n",
      "Epoch 19/100\n",
      "17/17 [==============================] - 0s 28ms/step - loss: 0.0075 - accuracy: 0.9986 - val_loss: 1.7925 - val_accuracy: 0.8229\n",
      "Epoch 20/100\n",
      "17/17 [==============================] - 0s 21ms/step - loss: 0.0372 - accuracy: 0.9900 - val_loss: 1.7763 - val_accuracy: 0.8214\n",
      "Epoch 21/100\n",
      "17/17 [==============================] - 1s 36ms/step - loss: 0.0129 - accuracy: 0.9950 - val_loss: 1.8582 - val_accuracy: 0.8029\n",
      "Epoch 22/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 0.0225 - accuracy: 0.9964 - val_loss: 1.8326 - val_accuracy: 0.8214\n",
      "Epoch 23/100\n",
      "17/17 [==============================] - 0s 20ms/step - loss: 0.0114 - accuracy: 0.9958 - val_loss: 1.9295 - val_accuracy: 0.8114\n",
      "Epoch 24/100\n",
      "17/17 [==============================] - 0s 23ms/step - loss: 0.0263 - accuracy: 0.9917 - val_loss: 2.1135 - val_accuracy: 0.7914\n",
      "Epoch 25/100\n",
      "17/17 [==============================] - 0s 19ms/step - loss: 0.0270 - accuracy: 0.9920 - val_loss: 1.9743 - val_accuracy: 0.8271\n",
      "Epoch 26/100\n",
      "17/17 [==============================] - 0s 27ms/step - loss: 0.0021 - accuracy: 1.0000 - val_loss: 1.9064 - val_accuracy: 0.8286\n",
      "Epoch 27/100\n",
      "17/17 [==============================] - 0s 29ms/step - loss: 2.4204e-04 - accuracy: 1.0000 - val_loss: 1.9064 - val_accuracy: 0.8257\n",
      "Epoch 28/100\n",
      "17/17 [==============================] - 0s 29ms/step - loss: 1.3295e-04 - accuracy: 1.0000 - val_loss: 1.9215 - val_accuracy: 0.8243\n",
      "Epoch 29/100\n",
      "17/17 [==============================] - 0s 20ms/step - loss: 7.3698e-05 - accuracy: 1.0000 - val_loss: 1.9316 - val_accuracy: 0.8243\n",
      "Epoch 30/100\n",
      "17/17 [==============================] - 0s 16ms/step - loss: 4.3785e-05 - accuracy: 1.0000 - val_loss: 1.9665 - val_accuracy: 0.8271\n",
      "Epoch 31/100\n",
      "17/17 [==============================] - 0s 21ms/step - loss: 2.3023e-05 - accuracy: 1.0000 - val_loss: 1.9833 - val_accuracy: 0.8271\n",
      "Epoch 32/100\n",
      "17/17 [==============================] - 0s 22ms/step - loss: 1.4395e-05 - accuracy: 1.0000 - val_loss: 2.0359 - val_accuracy: 0.8271\n",
      "Epoch 33/100\n",
      "17/17 [==============================] - 0s 17ms/step - loss: 8.3031e-06 - accuracy: 1.0000 - val_loss: 2.1776 - val_accuracy: 0.8200\n",
      "Epoch 34/100\n",
      "17/17 [==============================] - 0s 24ms/step - loss: 0.0226 - accuracy: 0.9964 - val_loss: 2.4028 - val_accuracy: 0.8229\n",
      "Epoch 35/100\n",
      "17/17 [==============================] - 0s 18ms/step - loss: 0.0145 - accuracy: 0.9952 - val_loss: 2.3524 - val_accuracy: 0.8214\n",
      "Epoch 36/100\n",
      "17/17 [==============================] - 0s 18ms/step - loss: 0.0046 - accuracy: 0.9988 - val_loss: 2.3372 - val_accuracy: 0.8300\n",
      "Epoch 37/100\n",
      "17/17 [==============================] - 0s 23ms/step - loss: 0.1186 - accuracy: 0.9874 - val_loss: 2.3193 - val_accuracy: 0.8329\n",
      "Epoch 38/100\n",
      "17/17 [==============================] - 0s 24ms/step - loss: 0.0016 - accuracy: 0.9998 - val_loss: 2.2864 - val_accuracy: 0.8300\n",
      "Epoch 39/100\n",
      "17/17 [==============================] - 0s 28ms/step - loss: 1.4535e-04 - accuracy: 1.0000 - val_loss: 2.3416 - val_accuracy: 0.8314\n",
      "Epoch 40/100\n",
      "17/17 [==============================] - 0s 18ms/step - loss: 7.5005e-05 - accuracy: 1.0000 - val_loss: 2.3577 - val_accuracy: 0.8300\n",
      "Epoch 41/100\n",
      "17/17 [==============================] - 1s 34ms/step - loss: 2.4917e-05 - accuracy: 1.0000 - val_loss: 2.3551 - val_accuracy: 0.8329\n",
      "Epoch 42/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 1.3168e-05 - accuracy: 1.0000 - val_loss: 2.3660 - val_accuracy: 0.8329\n",
      "Epoch 43/100\n",
      "17/17 [==============================] - 0s 29ms/step - loss: 8.4370e-06 - accuracy: 1.0000 - val_loss: 2.3815 - val_accuracy: 0.8300\n",
      "Epoch 44/100\n",
      "17/17 [==============================] - 1s 39ms/step - loss: 5.0662e-06 - accuracy: 1.0000 - val_loss: 2.3942 - val_accuracy: 0.8314\n",
      "Epoch 45/100\n",
      "17/17 [==============================] - 0s 25ms/step - loss: 2.4796e-06 - accuracy: 1.0000 - val_loss: 2.4293 - val_accuracy: 0.8329\n",
      "Epoch 46/100\n",
      "17/17 [==============================] - 1s 36ms/step - loss: 1.0115e-06 - accuracy: 1.0000 - val_loss: 2.4676 - val_accuracy: 0.8329\n",
      "Epoch 47/100\n",
      "17/17 [==============================] - 1s 39ms/step - loss: 5.7316e-07 - accuracy: 1.0000 - val_loss: 2.4995 - val_accuracy: 0.8329\n",
      "Epoch 48/100\n",
      "17/17 [==============================] - 0s 22ms/step - loss: 0.0061 - accuracy: 0.9984 - val_loss: 2.7894 - val_accuracy: 0.8300\n",
      "Epoch 49/100\n",
      "17/17 [==============================] - 0s 28ms/step - loss: 0.0354 - accuracy: 0.9911 - val_loss: 2.6485 - val_accuracy: 0.8329\n",
      "Epoch 50/100\n",
      "17/17 [==============================] - 0s 21ms/step - loss: 0.0115 - accuracy: 0.9988 - val_loss: 2.4846 - val_accuracy: 0.8386\n",
      "Epoch 51/100\n",
      "17/17 [==============================] - 1s 30ms/step - loss: 0.0042 - accuracy: 0.9983 - val_loss: 2.6672 - val_accuracy: 0.8386\n",
      "Epoch 52/100\n",
      "17/17 [==============================] - 0s 16ms/step - loss: 0.0470 - accuracy: 0.9909 - val_loss: 2.6604 - val_accuracy: 0.8457\n",
      "Epoch 53/100\n",
      "17/17 [==============================] - 0s 22ms/step - loss: 7.5008e-04 - accuracy: 0.9999 - val_loss: 2.9339 - val_accuracy: 0.8314\n",
      "Epoch 54/100\n",
      "17/17 [==============================] - 0s 30ms/step - loss: 0.0056 - accuracy: 0.9957 - val_loss: 2.9320 - val_accuracy: 0.8329\n",
      "Epoch 55/100\n",
      "17/17 [==============================] - 1s 43ms/step - loss: 0.0055 - accuracy: 0.9974 - val_loss: 3.1210 - val_accuracy: 0.8143\n",
      "Epoch 56/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 0.0107 - accuracy: 0.9963 - val_loss: 2.8705 - val_accuracy: 0.8271\n",
      "Epoch 57/100\n",
      "17/17 [==============================] - 0s 25ms/step - loss: 2.5691e-04 - accuracy: 1.0000 - val_loss: 2.9095 - val_accuracy: 0.8400\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 58/100\n",
      "17/17 [==============================] - 0s 26ms/step - loss: 7.5608e-05 - accuracy: 1.0000 - val_loss: 2.8919 - val_accuracy: 0.8371\n",
      "Epoch 59/100\n",
      "17/17 [==============================] - 1s 47ms/step - loss: 7.0834e-06 - accuracy: 1.0000 - val_loss: 2.8866 - val_accuracy: 0.8343\n",
      "Epoch 60/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 4.6043e-06 - accuracy: 1.0000 - val_loss: 2.8816 - val_accuracy: 0.8329\n",
      "Epoch 61/100\n",
      "17/17 [==============================] - 0s 24ms/step - loss: 2.9102e-06 - accuracy: 1.0000 - val_loss: 2.8800 - val_accuracy: 0.8329\n",
      "Epoch 62/100\n",
      "17/17 [==============================] - 0s 16ms/step - loss: 1.2977e-06 - accuracy: 1.0000 - val_loss: 2.8847 - val_accuracy: 0.8329\n",
      "Epoch 63/100\n",
      "17/17 [==============================] - 1s 36ms/step - loss: 8.5360e-07 - accuracy: 1.0000 - val_loss: 2.8969 - val_accuracy: 0.8300\n",
      "Epoch 64/100\n",
      "17/17 [==============================] - 0s 17ms/step - loss: 4.6366e-07 - accuracy: 1.0000 - val_loss: 2.9059 - val_accuracy: 0.8314\n",
      "Epoch 65/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 2.0024e-07 - accuracy: 1.0000 - val_loss: 2.9165 - val_accuracy: 0.8329\n",
      "Epoch 66/100\n",
      "17/17 [==============================] - 0s 15ms/step - loss: 1.1975e-07 - accuracy: 1.0000 - val_loss: 2.9278 - val_accuracy: 0.8343\n",
      "Epoch 67/100\n",
      "17/17 [==============================] - 0s 15ms/step - loss: 7.2557e-08 - accuracy: 1.0000 - val_loss: 2.9655 - val_accuracy: 0.8357\n",
      "Epoch 68/100\n",
      "17/17 [==============================] - 0s 22ms/step - loss: 3.2868e-08 - accuracy: 1.0000 - val_loss: 2.9889 - val_accuracy: 0.8329\n",
      "Epoch 69/100\n",
      "17/17 [==============================] - 0s 20ms/step - loss: 1.7399e-08 - accuracy: 1.0000 - val_loss: 3.0141 - val_accuracy: 0.8343\n",
      "Epoch 70/100\n",
      "17/17 [==============================] - 0s 22ms/step - loss: 1.0672e-08 - accuracy: 1.0000 - val_loss: 3.0596 - val_accuracy: 0.8371\n",
      "Epoch 71/100\n",
      "17/17 [==============================] - 0s 19ms/step - loss: 6.3904e-09 - accuracy: 1.0000 - val_loss: 3.1076 - val_accuracy: 0.8371\n",
      "Epoch 72/100\n",
      "17/17 [==============================] - 0s 21ms/step - loss: 8.6816e-09 - accuracy: 1.0000 - val_loss: 3.1891 - val_accuracy: 0.8371\n",
      "Epoch 73/100\n",
      "17/17 [==============================] - 0s 22ms/step - loss: 6.4090e-09 - accuracy: 1.0000 - val_loss: 3.2443 - val_accuracy: 0.8357\n",
      "Epoch 74/100\n",
      "17/17 [==============================] - 1s 36ms/step - loss: 2.6485e-09 - accuracy: 1.0000 - val_loss: 3.2851 - val_accuracy: 0.8329\n",
      "Epoch 75/100\n",
      "17/17 [==============================] - 1s 31ms/step - loss: 2.7955e-09 - accuracy: 1.0000 - val_loss: 3.3021 - val_accuracy: 0.8343\n",
      "Epoch 76/100\n",
      "17/17 [==============================] - 1s 30ms/step - loss: 1.6577e-09 - accuracy: 1.0000 - val_loss: 3.3525 - val_accuracy: 0.8329\n",
      "Epoch 77/100\n",
      "17/17 [==============================] - 0s 20ms/step - loss: 1.8567e-09 - accuracy: 1.0000 - val_loss: 3.3681 - val_accuracy: 0.8343\n",
      "Epoch 78/100\n",
      "17/17 [==============================] - 0s 25ms/step - loss: 1.3122e-09 - accuracy: 1.0000 - val_loss: 3.4063 - val_accuracy: 0.8329\n",
      "Epoch 79/100\n",
      "17/17 [==============================] - 0s 24ms/step - loss: 2.1913e-09 - accuracy: 1.0000 - val_loss: 3.4534 - val_accuracy: 0.8329\n",
      "Epoch 80/100\n",
      "17/17 [==============================] - 1s 33ms/step - loss: 8.2214e-10 - accuracy: 1.0000 - val_loss: 3.4914 - val_accuracy: 0.8314\n",
      "Epoch 81/100\n",
      "17/17 [==============================] - 0s 20ms/step - loss: 8.7293e-10 - accuracy: 1.0000 - val_loss: 3.5333 - val_accuracy: 0.8329\n",
      "Epoch 82/100\n",
      "17/17 [==============================] - 0s 18ms/step - loss: 7.6187e-10 - accuracy: 1.0000 - val_loss: 3.5691 - val_accuracy: 0.8314\n",
      "Epoch 83/100\n",
      "17/17 [==============================] - 0s 28ms/step - loss: 3.3811e-10 - accuracy: 1.0000 - val_loss: 3.5936 - val_accuracy: 0.8329\n",
      "Epoch 84/100\n",
      "17/17 [==============================] - 1s 37ms/step - loss: 8.5185e-10 - accuracy: 1.0000 - val_loss: 3.6297 - val_accuracy: 0.8314\n",
      "Epoch 85/100\n",
      "17/17 [==============================] - 1s 37ms/step - loss: 2.8228e-10 - accuracy: 1.0000 - val_loss: 3.6437 - val_accuracy: 0.8300\n",
      "Epoch 86/100\n",
      "17/17 [==============================] - 0s 27ms/step - loss: 4.2426e-10 - accuracy: 1.0000 - val_loss: 3.6538 - val_accuracy: 0.8314\n",
      "Epoch 87/100\n",
      "17/17 [==============================] - 0s 22ms/step - loss: 1.2045e-10 - accuracy: 1.0000 - val_loss: 3.6806 - val_accuracy: 0.8286\n",
      "Epoch 88/100\n",
      "17/17 [==============================] - 0s 22ms/step - loss: 6.4203e-11 - accuracy: 1.0000 - val_loss: 3.6964 - val_accuracy: 0.8271\n",
      "Epoch 89/100\n",
      "17/17 [==============================] - 0s 25ms/step - loss: 4.1939e-10 - accuracy: 1.0000 - val_loss: 3.7018 - val_accuracy: 0.8271\n",
      "Epoch 90/100\n",
      "17/17 [==============================] - 0s 23ms/step - loss: 1.7513e-10 - accuracy: 1.0000 - val_loss: 3.7258 - val_accuracy: 0.8271\n",
      "Epoch 91/100\n",
      "17/17 [==============================] - 0s 19ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 3.7429 - val_accuracy: 0.8271\n",
      "Epoch 92/100\n",
      "17/17 [==============================] - 0s 29ms/step - loss: 8.0794e-11 - accuracy: 1.0000 - val_loss: 3.7550 - val_accuracy: 0.8257\n",
      "Epoch 93/100\n",
      "17/17 [==============================] - 0s 24ms/step - loss: 2.4327e-10 - accuracy: 1.0000 - val_loss: 3.7616 - val_accuracy: 0.8271\n",
      "Epoch 94/100\n",
      "17/17 [==============================] - 0s 19ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 3.7754 - val_accuracy: 0.8257\n",
      "Epoch 95/100\n",
      "17/17 [==============================] - 0s 20ms/step - loss: 1.6670e-11 - accuracy: 1.0000 - val_loss: 3.7831 - val_accuracy: 0.8257\n",
      "Epoch 96/100\n",
      "17/17 [==============================] - 0s 23ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 3.7904 - val_accuracy: 0.8257\n",
      "Epoch 97/100\n",
      "17/17 [==============================] - 0s 26ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 3.7989 - val_accuracy: 0.8243\n",
      "Epoch 98/100\n",
      "17/17 [==============================] - 0s 25ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 3.8067 - val_accuracy: 0.8243\n",
      "Epoch 99/100\n",
      "17/17 [==============================] - 0s 19ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 3.8134 - val_accuracy: 0.8243\n",
      "Epoch 100/100\n",
      "17/17 [==============================] - 0s 28ms/step - loss: 8.9316e-11 - accuracy: 1.0000 - val_loss: 3.8167 - val_accuracy: 0.8257\n"
     ]
    }
   ],
   "source": [
    "## I now fit the model, and store the training history\n",
    "## I use 100 epochs and a batch_size of 100\n",
    "history = model.fit(X_train,\n",
    "                        to_categorical(y_train),\n",
    "                        epochs = 100,\n",
    "                        batch_size = 100,\n",
    "                        validation_data=(X_test,to_categorical(y_test)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 199,
   "metadata": {},
   "outputs": [],
   "source": [
    "history_dict = history.history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 200,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x432 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "set_style(\"whitegrid\")\n",
    "\n",
    "plt.figure(figsize = (10,6))\n",
    "\n",
    "plt.scatter(range(1,101), history_dict['accuracy'], label = \"Training Accuracy\")\n",
    "plt.scatter(range(1,101), history_dict['val_accuracy'], label = \"Validation Set Accuracy\")\n",
    "\n",
    "plt.xlabel(\"Epoch\", fontsize=18)\n",
    "plt.ylabel(\"Accuracy\", fontsize=18)\n",
    "\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "\n",
    "plt.legend(fontsize=18)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x432 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize = (10,6))\n",
    "\n",
    "plt.scatter(range(1,101), history_dict['loss'], label = \"Training Loss\")\n",
    "plt.scatter(range(1,101), history_dict['val_loss'], label = \"Validation Set Loss\")\n",
    "\n",
    "plt.xlabel(\"Epoch\", fontsize=18)\n",
    "plt.ylabel(\"Loss Function Value\", fontsize=18)\n",
    "\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "\n",
    "plt.legend(fontsize=18)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [],
   "source": [
    "fpr,tpr,thresholds=roc_curve(y_test,model.predict(X_test)[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x432 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize = (10,6))\n",
    "\n",
    "plt.scatter(fpr, tpr ,label = \"ROC curve\")\n",
    "\n",
    "plt.xlabel(\"fpr\", fontsize=18)\n",
    "plt.ylabel(\"tpr\", fontsize=18)\n",
    "\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "\n",
    "plt.legend(fontsize=18)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DNN on FFT features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 204,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_21\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_84 (Dense)             (None, 256)               1920256   \n",
      "_________________________________________________________________\n",
      "dense_85 (Dense)             (None, 256)               65792     \n",
      "_________________________________________________________________\n",
      "dense_86 (Dense)             (None, 256)               65792     \n",
      "_________________________________________________________________\n",
      "dense_87 (Dense)             (None, 2)                 514       \n",
      "=================================================================\n",
      "Total params: 2,052,354\n",
      "Trainable params: 2,052,354\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model_fft = models.Sequential()\n",
    "model_fft.add(layers.Dense(256, activation='relu', input_shape=(7500,)))\n",
    "model_fft.add(layers.Dense(256, activation='relu'))\n",
    "model_fft.add(layers.Dense(256, activation='relu'))\n",
    "model_fft.add(layers.Dense(2, activation='softmax'))\n",
    "model_fft.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "metadata": {},
   "outputs": [],
   "source": [
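    "# Compile with the Adam optimizer, categorical cross-entropy loss (targets are one-hot via to_categorical), and accuracy as the tracked metric\n",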
    "model_fft.compile(optimizer='adam',\n",
    "                  loss='categorical_crossentropy',\n",
    "                  metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 206,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "5/5 [==============================] - 1s 186ms/step - loss: 668.8413 - accuracy: 0.5319 - val_loss: 44.0859 - val_accuracy: 0.5014\n",
      "Epoch 2/100\n",
      "5/5 [==============================] - 1s 123ms/step - loss: 270.2084 - accuracy: 0.5108 - val_loss: 569.3049 - val_accuracy: 0.5100\n",
      "Epoch 3/100\n",
      "5/5 [==============================] - 1s 131ms/step - loss: 365.4353 - accuracy: 0.5163 - val_loss: 60.2744 - val_accuracy: 0.5100\n",
      "Epoch 4/100\n",
      "5/5 [==============================] - 1s 123ms/step - loss: 82.9828 - accuracy: 0.5218 - val_loss: 25.7717 - val_accuracy: 0.5514\n",
      "Epoch 5/100\n",
      "5/5 [==============================] - 1s 131ms/step - loss: 93.3083 - accuracy: 0.5303 - val_loss: 233.1957 - val_accuracy: 0.5100\n",
      "Epoch 6/100\n",
      "5/5 [==============================] - 0s 106ms/step - loss: 177.3547 - accuracy: 0.5145 - val_loss: 175.4946 - val_accuracy: 0.4900\n",
      "Epoch 7/100\n",
      "5/5 [==============================] - 1s 117ms/step - loss: 125.8578 - accuracy: 0.5587 - val_loss: 36.0824 - val_accuracy: 0.5171\n",
      "Epoch 8/100\n",
      "5/5 [==============================] - 0s 100ms/step - loss: 86.9217 - accuracy: 0.4990 - val_loss: 67.8607 - val_accuracy: 0.5086\n",
      "Epoch 9/100\n",
      "5/5 [==============================] - 0s 101ms/step - loss: 77.3385 - accuracy: 0.5346 - val_loss: 84.1572 - val_accuracy: 0.4971\n",
      "Epoch 10/100\n",
      "5/5 [==============================] - 1s 160ms/step - loss: 55.2668 - accuracy: 0.5356 - val_loss: 22.9971 - val_accuracy: 0.6771\n",
      "Epoch 11/100\n",
      "5/5 [==============================] - 1s 148ms/step - loss: 25.1741 - accuracy: 0.6633 - val_loss: 16.0151 - val_accuracy: 0.6743\n",
      "Epoch 12/100\n",
      "5/5 [==============================] - 1s 130ms/step - loss: 15.4722 - accuracy: 0.6968 - val_loss: 10.9474 - val_accuracy: 0.7114\n",
      "Epoch 13/100\n",
      "5/5 [==============================] - 1s 127ms/step - loss: 9.7629 - accuracy: 0.7271 - val_loss: 7.0434 - val_accuracy: 0.7429\n",
      "Epoch 14/100\n",
      "5/5 [==============================] - 1s 178ms/step - loss: 7.3605 - accuracy: 0.7383 - val_loss: 5.8756 - val_accuracy: 0.7329\n",
      "Epoch 15/100\n",
      "5/5 [==============================] - 1s 152ms/step - loss: 6.1939 - accuracy: 0.7216 - val_loss: 3.7403 - val_accuracy: 0.7371\n",
      "Epoch 16/100\n",
      "5/5 [==============================] - 1s 148ms/step - loss: 4.5153 - accuracy: 0.7219 - val_loss: 9.5329 - val_accuracy: 0.6243\n",
      "Epoch 17/100\n",
      "5/5 [==============================] - 1s 167ms/step - loss: 7.2322 - accuracy: 0.6801 - val_loss: 4.3928 - val_accuracy: 0.7571\n",
      "Epoch 18/100\n",
      "5/5 [==============================] - 1s 112ms/step - loss: 4.6078 - accuracy: 0.7660 - val_loss: 5.8847 - val_accuracy: 0.7343\n",
      "Epoch 19/100\n",
      "5/5 [==============================] - 1s 121ms/step - loss: 4.8300 - accuracy: 0.7732 - val_loss: 10.0567 - val_accuracy: 0.6243\n",
      "Epoch 20/100\n",
      "5/5 [==============================] - 1s 107ms/step - loss: 7.2900 - accuracy: 0.6803 - val_loss: 3.7964 - val_accuracy: 0.7514\n",
      "Epoch 21/100\n",
      "5/5 [==============================] - 1s 161ms/step - loss: 3.6784 - accuracy: 0.7823 - val_loss: 2.9960 - val_accuracy: 0.8100\n",
      "Epoch 22/100\n",
      "5/5 [==============================] - 1s 144ms/step - loss: 3.4123 - accuracy: 0.7908 - val_loss: 2.5648 - val_accuracy: 0.8029\n",
      "Epoch 23/100\n",
      "5/5 [==============================] - 1s 145ms/step - loss: 2.2857 - accuracy: 0.8270 - val_loss: 5.0830 - val_accuracy: 0.6671\n",
      "Epoch 24/100\n",
      "5/5 [==============================] - 1s 154ms/step - loss: 4.2492 - accuracy: 0.7031 - val_loss: 4.2682 - val_accuracy: 0.7129\n",
      "Epoch 25/100\n",
      "5/5 [==============================] - 1s 149ms/step - loss: 2.6757 - accuracy: 0.7756 - val_loss: 6.0220 - val_accuracy: 0.6643\n",
      "Epoch 26/100\n",
      "5/5 [==============================] - 1s 112ms/step - loss: 4.1686 - accuracy: 0.7316 - val_loss: 8.2055 - val_accuracy: 0.6429\n",
      "Epoch 27/100\n",
      "5/5 [==============================] - 1s 153ms/step - loss: 5.5506 - accuracy: 0.7262 - val_loss: 8.7718 - val_accuracy: 0.6386\n",
      "Epoch 28/100\n",
      "5/5 [==============================] - 1s 149ms/step - loss: 5.4512 - accuracy: 0.7161 - val_loss: 8.9510 - val_accuracy: 0.6500\n",
      "Epoch 29/100\n",
      "5/5 [==============================] - 1s 165ms/step - loss: 5.9735 - accuracy: 0.7098 - val_loss: 3.6430 - val_accuracy: 0.7700\n",
      "Epoch 30/100\n",
      "5/5 [==============================] - 1s 119ms/step - loss: 3.3268 - accuracy: 0.7875 - val_loss: 2.8525 - val_accuracy: 0.7829\n",
      "Epoch 31/100\n",
      "5/5 [==============================] - 1s 124ms/step - loss: 1.8482 - accuracy: 0.8250 - val_loss: 1.8496 - val_accuracy: 0.8071\n",
      "Epoch 32/100\n",
      "5/5 [==============================] - 1s 118ms/step - loss: 0.8518 - accuracy: 0.8750 - val_loss: 2.1047 - val_accuracy: 0.7714\n",
      "Epoch 33/100\n",
      "5/5 [==============================] - 1s 121ms/step - loss: 0.9209 - accuracy: 0.8615 - val_loss: 2.8018 - val_accuracy: 0.7429\n",
      "Epoch 34/100\n",
      "5/5 [==============================] - 1s 135ms/step - loss: 1.2811 - accuracy: 0.8451 - val_loss: 2.0684 - val_accuracy: 0.8157\n",
      "Epoch 35/100\n",
      "5/5 [==============================] - 1s 178ms/step - loss: 1.0505 - accuracy: 0.8781 - val_loss: 1.4359 - val_accuracy: 0.8500\n",
      "Epoch 36/100\n",
      "5/5 [==============================] - 1s 116ms/step - loss: 0.6192 - accuracy: 0.8989 - val_loss: 1.4616 - val_accuracy: 0.8229\n",
      "Epoch 37/100\n",
      "5/5 [==============================] - 1s 135ms/step - loss: 0.5381 - accuracy: 0.8999 - val_loss: 1.8011 - val_accuracy: 0.8114\n",
      "Epoch 38/100\n",
      "5/5 [==============================] - 1s 162ms/step - loss: 0.7226 - accuracy: 0.9003 - val_loss: 1.3293 - val_accuracy: 0.8371\n",
      "Epoch 39/100\n",
      "5/5 [==============================] - 1s 152ms/step - loss: 0.4180 - accuracy: 0.9221 - val_loss: 1.4263 - val_accuracy: 0.8429\n",
      "Epoch 40/100\n",
      "5/5 [==============================] - 1s 177ms/step - loss: 0.4432 - accuracy: 0.9168 - val_loss: 1.4712 - val_accuracy: 0.8214\n",
      "Epoch 41/100\n",
      "5/5 [==============================] - 1s 144ms/step - loss: 0.4285 - accuracy: 0.9065 - val_loss: 1.3988 - val_accuracy: 0.8443\n",
      "Epoch 42/100\n",
      "5/5 [==============================] - 1s 152ms/step - loss: 0.3796 - accuracy: 0.9174 - val_loss: 1.5349 - val_accuracy: 0.8314\n",
      "Epoch 43/100\n",
      "5/5 [==============================] - 1s 135ms/step - loss: 0.4479 - accuracy: 0.9121 - val_loss: 1.3893 - val_accuracy: 0.8557\n",
      "Epoch 44/100\n",
      "5/5 [==============================] - 1s 205ms/step - loss: 0.2786 - accuracy: 0.9392 - val_loss: 3.3798 - val_accuracy: 0.7086\n",
      "Epoch 45/100\n",
      "5/5 [==============================] - 1s 189ms/step - loss: 1.1848 - accuracy: 0.8274 - val_loss: 1.5555 - val_accuracy: 0.8371\n",
      "Epoch 46/100\n",
      "5/5 [==============================] - 1s 169ms/step - loss: 0.5219 - accuracy: 0.9136 - val_loss: 1.4113 - val_accuracy: 0.8371\n",
      "Epoch 47/100\n",
      "5/5 [==============================] - 1s 154ms/step - loss: 0.3574 - accuracy: 0.9249 - val_loss: 1.5458 - val_accuracy: 0.8357\n",
      "Epoch 48/100\n",
      "5/5 [==============================] - 1s 124ms/step - loss: 0.3093 - accuracy: 0.9423 - val_loss: 1.4719 - val_accuracy: 0.8314\n",
      "Epoch 49/100\n",
      "5/5 [==============================] - 1s 151ms/step - loss: 0.3219 - accuracy: 0.9313 - val_loss: 1.2659 - val_accuracy: 0.8629\n",
      "Epoch 50/100\n",
      "5/5 [==============================] - 1s 171ms/step - loss: 0.2089 - accuracy: 0.9521 - val_loss: 2.2221 - val_accuracy: 0.7771\n",
      "Epoch 51/100\n",
      "5/5 [==============================] - 1s 139ms/step - loss: 0.5970 - accuracy: 0.8885 - val_loss: 1.6676 - val_accuracy: 0.8214\n",
      "Epoch 52/100\n",
      "5/5 [==============================] - 1s 169ms/step - loss: 0.2218 - accuracy: 0.9367 - val_loss: 2.9995 - val_accuracy: 0.7414\n",
      "Epoch 53/100\n",
      "5/5 [==============================] - 1s 173ms/step - loss: 0.8647 - accuracy: 0.8692 - val_loss: 1.4096 - val_accuracy: 0.8571\n",
      "Epoch 54/100\n",
      "5/5 [==============================] - 1s 201ms/step - loss: 0.4277 - accuracy: 0.9210 - val_loss: 1.3716 - val_accuracy: 0.8557\n",
      "Epoch 55/100\n",
      "5/5 [==============================] - 1s 155ms/step - loss: 0.1871 - accuracy: 0.9521 - val_loss: 1.4880 - val_accuracy: 0.8286\n",
      "Epoch 56/100\n",
      "5/5 [==============================] - 1s 139ms/step - loss: 0.2525 - accuracy: 0.9379 - val_loss: 1.2095 - val_accuracy: 0.8571\n",
      "Epoch 57/100\n",
      "5/5 [==============================] - 1s 200ms/step - loss: 0.1593 - accuracy: 0.9504 - val_loss: 1.1502 - val_accuracy: 0.8786\n",
      "Epoch 58/100\n",
      "5/5 [==============================] - 1s 169ms/step - loss: 0.1301 - accuracy: 0.9653 - val_loss: 1.2333 - val_accuracy: 0.8686\n",
      "Epoch 59/100\n",
      "5/5 [==============================] - 1s 155ms/step - loss: 0.1333 - accuracy: 0.9599 - val_loss: 1.1907 - val_accuracy: 0.8771\n",
      "Epoch 60/100\n",
      "5/5 [==============================] - 1s 137ms/step - loss: 0.1730 - accuracy: 0.9557 - val_loss: 1.5155 - val_accuracy: 0.8314\n",
      "Epoch 61/100\n",
      "5/5 [==============================] - 1s 177ms/step - loss: 0.1984 - accuracy: 0.9503 - val_loss: 1.5048 - val_accuracy: 0.8414\n",
      "Epoch 62/100\n",
      "5/5 [==============================] - 1s 157ms/step - loss: 0.1791 - accuracy: 0.9458 - val_loss: 2.5472 - val_accuracy: 0.7543\n",
      "Epoch 63/100\n",
      "5/5 [==============================] - 1s 143ms/step - loss: 0.7222 - accuracy: 0.8597 - val_loss: 1.9411 - val_accuracy: 0.8214\n",
      "Epoch 64/100\n",
      "5/5 [==============================] - 1s 136ms/step - loss: 0.7179 - accuracy: 0.8668 - val_loss: 1.3057 - val_accuracy: 0.8871\n",
      "Epoch 65/100\n",
      "5/5 [==============================] - 1s 143ms/step - loss: 0.3293 - accuracy: 0.9339 - val_loss: 2.7477 - val_accuracy: 0.7557\n",
      "Epoch 66/100\n",
      "5/5 [==============================] - 1s 137ms/step - loss: 1.2682 - accuracy: 0.8158 - val_loss: 3.8205 - val_accuracy: 0.7400\n",
      "Epoch 67/100\n",
      "5/5 [==============================] - 1s 182ms/step - loss: 1.3911 - accuracy: 0.8621 - val_loss: 2.7817 - val_accuracy: 0.7714\n",
      "Epoch 68/100\n",
      "5/5 [==============================] - 1s 182ms/step - loss: 0.9739 - accuracy: 0.8762 - val_loss: 3.5076 - val_accuracy: 0.7186\n",
      "Epoch 69/100\n",
      "5/5 [==============================] - 1s 142ms/step - loss: 0.9855 - accuracy: 0.8385 - val_loss: 3.3234 - val_accuracy: 0.7571\n",
      "Epoch 70/100\n",
      "5/5 [==============================] - 1s 144ms/step - loss: 0.7558 - accuracy: 0.8736 - val_loss: 1.9545 - val_accuracy: 0.8386\n",
      "Epoch 71/100\n",
      "5/5 [==============================] - 1s 150ms/step - loss: 0.3468 - accuracy: 0.9257 - val_loss: 1.5358 - val_accuracy: 0.8671\n",
      "Epoch 72/100\n",
      "5/5 [==============================] - 1s 169ms/step - loss: 0.1140 - accuracy: 0.9723 - val_loss: 1.7201 - val_accuracy: 0.8300\n",
      "Epoch 73/100\n",
      "5/5 [==============================] - 1s 139ms/step - loss: 0.1078 - accuracy: 0.9663 - val_loss: 1.5763 - val_accuracy: 0.8529\n",
      "Epoch 74/100\n",
      "5/5 [==============================] - 1s 157ms/step - loss: 0.2788 - accuracy: 0.9436 - val_loss: 1.2660 - val_accuracy: 0.8900\n",
      "Epoch 75/100\n",
      "5/5 [==============================] - 1s 186ms/step - loss: 0.0831 - accuracy: 0.9695 - val_loss: 1.7341 - val_accuracy: 0.8443\n",
      "Epoch 76/100\n",
      "5/5 [==============================] - 1s 172ms/step - loss: 0.2608 - accuracy: 0.9517 - val_loss: 1.3608 - val_accuracy: 0.8857\n",
      "Epoch 77/100\n",
      "5/5 [==============================] - 1s 225ms/step - loss: 0.1023 - accuracy: 0.9766 - val_loss: 2.0198 - val_accuracy: 0.8271\n",
      "Epoch 78/100\n",
      "5/5 [==============================] - 1s 168ms/step - loss: 0.1277 - accuracy: 0.9648 - val_loss: 1.3196 - val_accuracy: 0.8786\n",
      "Epoch 79/100\n",
      "5/5 [==============================] - 1s 157ms/step - loss: 0.0739 - accuracy: 0.9771 - val_loss: 1.4343 - val_accuracy: 0.8871\n",
      "Epoch 80/100\n",
      "5/5 [==============================] - 1s 169ms/step - loss: 0.0498 - accuracy: 0.9869 - val_loss: 1.6975 - val_accuracy: 0.8414\n",
      "Epoch 81/100\n",
      "5/5 [==============================] - 1s 151ms/step - loss: 0.0372 - accuracy: 0.9883 - val_loss: 1.3767 - val_accuracy: 0.8743\n",
      "Epoch 82/100\n",
      "5/5 [==============================] - 1s 153ms/step - loss: 0.0456 - accuracy: 0.9839 - val_loss: 1.3798 - val_accuracy: 0.8643\n",
      "Epoch 83/100\n",
      "5/5 [==============================] - 1s 135ms/step - loss: 0.0134 - accuracy: 0.9930 - val_loss: 1.2891 - val_accuracy: 0.8829\n",
      "Epoch 84/100\n",
      "5/5 [==============================] - 1s 155ms/step - loss: 0.0046 - accuracy: 0.9993 - val_loss: 1.2904 - val_accuracy: 0.8814\n",
      "Epoch 85/100\n",
      "5/5 [==============================] - 1s 136ms/step - loss: 0.0074 - accuracy: 0.9971 - val_loss: 1.3100 - val_accuracy: 0.8800\n",
      "Epoch 86/100\n",
      "5/5 [==============================] - 1s 126ms/step - loss: 0.0063 - accuracy: 0.9985 - val_loss: 1.3457 - val_accuracy: 0.8800\n",
      "Epoch 87/100\n",
      "5/5 [==============================] - 1s 151ms/step - loss: 0.0038 - accuracy: 1.0000 - val_loss: 1.2894 - val_accuracy: 0.8814\n",
      "Epoch 88/100\n",
      "5/5 [==============================] - 1s 156ms/step - loss: 0.0030 - accuracy: 0.9987 - val_loss: 1.3312 - val_accuracy: 0.8743\n",
      "Epoch 89/100\n",
      "5/5 [==============================] - 1s 164ms/step - loss: 0.0027 - accuracy: 1.0000 - val_loss: 1.3133 - val_accuracy: 0.8743\n",
      "Epoch 90/100\n",
      "5/5 [==============================] - 1s 164ms/step - loss: 0.0021 - accuracy: 1.0000 - val_loss: 1.2968 - val_accuracy: 0.8771\n",
      "Epoch 91/100\n",
      "5/5 [==============================] - 1s 158ms/step - loss: 0.0014 - accuracy: 1.0000 - val_loss: 1.3257 - val_accuracy: 0.8757\n",
      "Epoch 92/100\n",
      "5/5 [==============================] - 1s 169ms/step - loss: 0.0019 - accuracy: 1.0000 - val_loss: 1.3199 - val_accuracy: 0.8757\n",
      "Epoch 93/100\n",
      "5/5 [==============================] - 1s 171ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 1.2915 - val_accuracy: 0.8814\n",
      "Epoch 94/100\n",
      "5/5 [==============================] - 1s 251ms/step - loss: 0.0017 - accuracy: 1.0000 - val_loss: 1.3020 - val_accuracy: 0.8800\n",
      "Epoch 95/100\n",
      "5/5 [==============================] - 1s 225ms/step - loss: 0.0013 - accuracy: 1.0000 - val_loss: 1.3131 - val_accuracy: 0.8786\n",
      "Epoch 96/100\n",
      "5/5 [==============================] - 1s 149ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 1.2945 - val_accuracy: 0.8800\n",
      "Epoch 97/100\n",
      "5/5 [==============================] - 1s 167ms/step - loss: 0.0010 - accuracy: 1.0000 - val_loss: 1.3119 - val_accuracy: 0.8843\n",
      "Epoch 98/100\n",
      "5/5 [==============================] - 1s 157ms/step - loss: 0.0010 - accuracy: 1.0000 - val_loss: 1.3006 - val_accuracy: 0.8814\n",
      "Epoch 99/100\n",
      "5/5 [==============================] - 1s 166ms/step - loss: 0.0011 - accuracy: 1.0000 - val_loss: 1.3015 - val_accuracy: 0.8814\n",
      "Epoch 100/100\n",
      "5/5 [==============================] - 1s 141ms/step - loss: 0.0011 - accuracy: 1.0000 - val_loss: 1.3061 - val_accuracy: 0.8829\n"
     ]
    }
   ],
   "source": [
    "## I now fit the model, and store the training history\n",
    "## I use 100 epochs and a batch_size of 400\n",
    "history_fft = model_fft.fit(X_fft_train,\n",
    "                        to_categorical(y_fft_train),\n",
    "                        epochs = 100,\n",
    "                        batch_size = 400,\n",
    "                        validation_data=(X_fft_test,to_categorical(y_fft_test)))\n",
    "history_fft_dict = history_fft.history"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gaussian Naive Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 208,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6757142857142857\n",
      "0.7663934426229508\n",
      "0.5238095238095238\n"
     ]
    }
   ],
   "source": [
    "from sklearn.naive_bayes import GaussianNB\n",
    "model_gnb = GaussianNB()\n",
    "model_gnb.fit(X_train, y_train);\n",
    "y_pred_gnb = model_gnb.predict(X_test)\n",
    "print( accuracy_score(y_test,y_pred_gnb))\n",
    "print( precision_score(y_test,y_pred_gnb))\n",
    "print( recall_score(y_test,y_pred_gnb))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 209,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6485714285714286\n",
      "0.8135593220338984\n",
      "0.40336134453781514\n"
     ]
    }
   ],
   "source": [
    "from sklearn.naive_bayes import GaussianNB\n",
    "model_gnb = GaussianNB()\n",
    "model_gnb.fit(X_fft_train, y_fft_train);\n",
    "y_fft_pred_gnb = model_gnb.predict(X_fft_test)\n",
    "print( accuracy_score(y_fft_test,y_fft_pred_gnb))\n",
    "print( precision_score(y_fft_test,y_fft_pred_gnb))\n",
    "print( recall_score(y_fft_test,y_fft_pred_gnb))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Logistic Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7685714285714286\n",
      "0.7545691906005222\n",
      "0.8095238095238095\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "log_reg = LogisticRegression(max_iter=100000)\n",
    "log_reg.fit(X_train,y_train)\n",
    "y_pred_log_reg = log_reg.predict(X_test)\n",
    "print( accuracy_score(y_test,y_pred_log_reg))\n",
    "print( precision_score(y_test,y_pred_log_reg))\n",
    "print( recall_score(y_test,y_pred_log_reg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8614285714285714\n",
      "0.8282828282828283\n",
      "0.9187675070028011\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "log_reg = LogisticRegression(max_iter=100000)\n",
    "log_reg.fit(X_fft_train[:,:],y_fft_train)\n",
    "y_fft_pred_log_reg = log_reg.predict(X_fft_test[:,:])\n",
    "print( accuracy_score(y_fft_test,y_fft_pred_log_reg))\n",
    "print( precision_score(y_fft_test,y_fft_pred_log_reg))\n",
    "print( recall_score(y_fft_test,y_fft_pred_log_reg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}