[0d34a1]: / model / model.ipynb

Download this file

499 lines (498 with data), 13.6 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stress Detection Using ML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reading the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>net_acc_mean</th>\n",
       "      <th>net_acc_std</th>\n",
       "      <th>net_acc_min</th>\n",
       "      <th>net_acc_max</th>\n",
       "      <th>ACC_x_mean</th>\n",
       "      <th>ACC_x_std</th>\n",
       "      <th>ACC_x_min</th>\n",
       "      <th>ACC_x_max</th>\n",
       "      <th>ACC_y_mean</th>\n",
       "      <th>ACC_y_std</th>\n",
       "      <th>...</th>\n",
       "      <th>age</th>\n",
       "      <th>height</th>\n",
       "      <th>weight</th>\n",
       "      <th>gender_ female</th>\n",
       "      <th>gender_ male</th>\n",
       "      <th>coffee_today_YES</th>\n",
       "      <th>sport_today_YES</th>\n",
       "      <th>smoker_NO</th>\n",
       "      <th>smoker_YES</th>\n",
       "      <th>feel_ill_today_YES</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.029937</td>\n",
       "      <td>0.009942</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.087383</td>\n",
       "      <td>0.029510</td>\n",
       "      <td>0.011145</td>\n",
       "      <td>-0.024082</td>\n",
       "      <td>0.087383</td>\n",
       "      <td>0.000020</td>\n",
       "      <td>0.000008</td>\n",
       "      <td>...</td>\n",
       "      <td>27</td>\n",
       "      <td>175</td>\n",
       "      <td>80</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.021986</td>\n",
       "      <td>0.015845</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.071558</td>\n",
       "      <td>0.017352</td>\n",
       "      <td>0.020817</td>\n",
       "      <td>-0.037843</td>\n",
       "      <td>0.071558</td>\n",
       "      <td>0.000012</td>\n",
       "      <td>0.000014</td>\n",
       "      <td>...</td>\n",
       "      <td>27</td>\n",
       "      <td>175</td>\n",
       "      <td>80</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.020839</td>\n",
       "      <td>0.011034</td>\n",
       "      <td>0.002752</td>\n",
       "      <td>0.054356</td>\n",
       "      <td>0.020839</td>\n",
       "      <td>0.011034</td>\n",
       "      <td>0.002752</td>\n",
       "      <td>0.054356</td>\n",
       "      <td>0.000014</td>\n",
       "      <td>0.000008</td>\n",
       "      <td>...</td>\n",
       "      <td>27</td>\n",
       "      <td>175</td>\n",
       "      <td>80</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.034449</td>\n",
       "      <td>0.003185</td>\n",
       "      <td>0.013761</td>\n",
       "      <td>0.040595</td>\n",
       "      <td>0.034449</td>\n",
       "      <td>0.003185</td>\n",
       "      <td>0.013761</td>\n",
       "      <td>0.040595</td>\n",
       "      <td>0.000024</td>\n",
       "      <td>0.000002</td>\n",
       "      <td>...</td>\n",
       "      <td>27</td>\n",
       "      <td>175</td>\n",
       "      <td>80</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>4 rows × 58 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   net_acc_mean  net_acc_std  net_acc_min  net_acc_max  ACC_x_mean  ACC_x_std  \\\n",
       "0      0.029937     0.009942     0.000000     0.087383    0.029510   0.011145   \n",
       "1      0.021986     0.015845     0.000000     0.071558    0.017352   0.020817   \n",
       "2      0.020839     0.011034     0.002752     0.054356    0.020839   0.011034   \n",
       "3      0.034449     0.003185     0.013761     0.040595    0.034449   0.003185   \n",
       "\n",
       "   ACC_x_min  ACC_x_max  ACC_y_mean  ACC_y_std  ...  age  height  weight  \\\n",
       "0  -0.024082   0.087383    0.000020   0.000008  ...   27     175      80   \n",
       "1  -0.037843   0.071558    0.000012   0.000014  ...   27     175      80   \n",
       "2   0.002752   0.054356    0.000014   0.000008  ...   27     175      80   \n",
       "3   0.013761   0.040595    0.000024   0.000002  ...   27     175      80   \n",
       "\n",
       "   gender_ female  gender_ male  coffee_today_YES  sport_today_YES  smoker_NO  \\\n",
       "0               0             1                 0                0          1   \n",
       "1               0             1                 0                0          1   \n",
       "2               0             1                 0                0          1   \n",
       "3               0             1                 0                0          1   \n",
       "\n",
       "   smoker_YES  feel_ill_today_YES  \n",
       "0           0                   0  \n",
       "1           0                   0  \n",
       "2           0                   0  \n",
       "3           0                   0  \n",
       "\n",
       "[4 rows x 58 columns]"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('data/merged.csv', index_col=0)\n",
    "df.head(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['net_acc_mean', 'net_acc_std', 'net_acc_min', 'net_acc_max',\n",
       "       'ACC_x_mean', 'ACC_x_std', 'ACC_x_min', 'ACC_x_max', 'ACC_y_mean',\n",
       "       'ACC_y_std', 'ACC_y_min', 'ACC_y_max', 'ACC_z_mean', 'ACC_z_std',\n",
       "       'ACC_z_min', 'ACC_z_max', 'BVP_mean', 'BVP_std', 'BVP_min', 'BVP_max',\n",
       "       'EDA_mean', 'EDA_std', 'EDA_min', 'EDA_max', 'EDA_phasic_mean',\n",
       "       'EDA_phasic_std', 'EDA_phasic_min', 'EDA_phasic_max', 'EDA_smna_mean',\n",
       "       'EDA_smna_std', 'EDA_smna_min', 'EDA_smna_max', 'EDA_tonic_mean',\n",
       "       'EDA_tonic_std', 'EDA_tonic_min', 'EDA_tonic_max', 'Resp_mean',\n",
       "       'Resp_std', 'Resp_min', 'Resp_max', 'TEMP_mean', 'TEMP_std', 'TEMP_min',\n",
       "       'TEMP_max', 'BVP_peak_freq', 'TEMP_slope', 'subject', 'label', 'age',\n",
       "       'height', 'weight', 'gender_ female', 'gender_ male',\n",
       "       'coffee_today_YES', 'sport_today_YES', 'smoker_NO', 'smoker_YES',\n",
       "       'feel_ill_today_YES'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2])"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(df['label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display names keyed by the integer class ids (0, 1, 2).\n",
    "labels = dict(enumerate((\"Amused\", \"Neutral\", \"Stressed\")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Feature Selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "nbconvert_remove"
    ]
   },
   "outputs": [],
   "source": [
    "# nbconvert_remove\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Keep the correlations with the target as a one-column DataFrame so seaborn\n",
    "# labels every row straight from the index. Reshaping by len(df.columns) and\n",
    "# relabelling the ticks by hand fails as soon as corr() returns fewer rows\n",
    "# than df has columns, or seaborn thins out its automatic tick labels.\n",
    "cor = df.corr()\n",
    "cor_target = cor[['label']]\n",
    "\n",
    "# Explicit figure/axes; the tall, narrow canvas leaves one readable row per feature.\n",
    "fig, ax = plt.subplots(figsize=(2, 100))\n",
    "sns.heatmap(cor_target, annot=True, cmap=plt.cm.Accent_r, yticklabels=True, ax=ax)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1178, 15), (1178,))"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Columns fed to the model. The order is significant: it fixes the column\n",
    "# order of X and therefore the feature order the classifier is trained on.\n",
    "selected_feats = [\n",
    "    'BVP_mean', 'BVP_std',\n",
    "    'EDA_phasic_mean', 'EDA_phasic_min', 'EDA_smna_min', 'EDA_tonic_mean',\n",
    "    'Resp_mean', 'Resp_std',\n",
    "    'TEMP_mean', 'TEMP_std', 'TEMP_slope',\n",
    "    'BVP_peak_freq',\n",
    "    'age', 'height', 'weight',\n",
    "]\n",
    "\n",
    "# Feature matrix and target vector.\n",
    "X = df.loc[:, selected_feats]\n",
    "y = df.loc[:, 'label']\n",
    "\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ML Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.ensemble import RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1060, 15), (118, 15))"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Split features and labels in one call so every row stays paired with its\n",
    "# label by construction, instead of relying on two separate calls that only\n",
    "# line up because they share the same seed and input length.\n",
    "# NOTE(review): rows sharing a `subject` value can land in both splits;\n",
    "# consider a group-aware split if cross-subject generalisation matters.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=0.1, random_state=0\n",
    ")\n",
    "\n",
    "X_train.shape, X_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the forest (same seed as the train/test split) so that re-running the\n",
    "# notebook reproduces the same trained model and the same reported accuracy.\n",
    "model = RandomForestClassifier(random_state=0)\n",
    "model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(predicted, actual):\n",
    "    \"\"\"Return the percentage of predictions that equal the true labels.\n",
    "\n",
    "    Args:\n",
    "        predicted: Sequence of predicted labels. Items may be scalars or\n",
    "            one-element arrays; the input is flattened before comparison.\n",
    "        actual: Sequence of true labels, one per prediction.\n",
    "\n",
    "    Returns:\n",
    "        float: Share of positions where both inputs agree, from 0 to 100.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the inputs are empty or differ in length.\n",
    "    \"\"\"\n",
    "    predicted = np.asarray(predicted).ravel()\n",
    "    actual = np.asarray(actual).ravel()\n",
    "\n",
    "    if predicted.size == 0:\n",
    "        raise ValueError('accuracy() needs at least one prediction')\n",
    "    # zip() used to truncate silently here, hiding misaligned inputs.\n",
    "    if predicted.size != actual.size:\n",
    "        raise ValueError(\n",
    "            f'length mismatch: {predicted.size} predictions vs {actual.size} labels'\n",
    "        )\n",
    "\n",
    "    n_correct = np.count_nonzero(predicted == actual)\n",
    "    return n_correct / predicted.size * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(arr):\n",
    "    \"\"\"Classify one sample with the module-level `model`.\n",
    "\n",
    "    Args:\n",
    "        arr: Array-like of feature values; it is reshaped into a single row\n",
    "            before being handed to the model.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray: Flattened output of `model.predict` for that row.\n",
    "    \"\"\"\n",
    "    sample = np.array(arr).reshape(1, -1)\n",
    "    return model.predict(sample).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "tags": [
     "nbconvert_remove"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "95.76271186440678"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the single-sample helper over every held-out row, then score the result.\n",
    "predicted = [predict(row) for row in X_test.values]\n",
    "\n",
    "accuracy(predicted, y_test.values)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving the trained model to a pickle file for later use by the API's prediction function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "filename = 'trained_model.sav'\n",
    "# The context manager flushes and closes the file as soon as the dump ends;\n",
    "# a bare open() left the handle dangling until garbage collection.\n",
    "with open(filename, 'wb') as model_file:\n",
    "    pickle.dump(model, model_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.7 ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "b56aae866cfdd3dd1993badfb61811822ff858e2a83b734b90ea6aa544e22f54"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}