--- a +++ b/model/model.ipynb @@ -0,0 +1,498 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Stress Detection Using ML" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reading the Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>net_acc_mean</th>\n", + " <th>net_acc_std</th>\n", + " <th>net_acc_min</th>\n", + " <th>net_acc_max</th>\n", + " <th>ACC_x_mean</th>\n", + " <th>ACC_x_std</th>\n", + " <th>ACC_x_min</th>\n", + " <th>ACC_x_max</th>\n", + " <th>ACC_y_mean</th>\n", + " <th>ACC_y_std</th>\n", + " <th>...</th>\n", + " <th>age</th>\n", + " <th>height</th>\n", + " <th>weight</th>\n", + " <th>gender_ female</th>\n", + " <th>gender_ male</th>\n", + " <th>coffee_today_YES</th>\n", + " <th>sport_today_YES</th>\n", + " <th>smoker_NO</th>\n", + " <th>smoker_YES</th>\n", + " <th>feel_ill_today_YES</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>0.029937</td>\n", + " <td>0.009942</td>\n", + " <td>0.000000</td>\n", + " <td>0.087383</td>\n", + " <td>0.029510</td>\n", + " <td>0.011145</td>\n", + " <td>-0.024082</td>\n", + " <td>0.087383</td>\n", + " <td>0.000020</td>\n", + " <td>0.000008</td>\n", + " <td>...</td>\n", + " <td>27</td>\n", + " <td>175</td>\n", + " <td>80</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>0.021986</td>\n", + " <td>0.015845</td>\n", + " <td>0.000000</td>\n", + " <td>0.071558</td>\n", + " <td>0.017352</td>\n", + " <td>0.020817</td>\n", + " <td>-0.037843</td>\n", + " <td>0.071558</td>\n", + " <td>0.000012</td>\n", + " <td>0.000014</td>\n", + " <td>...</td>\n", + " <td>27</td>\n", + " <td>175</td>\n", + " <td>80</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>0.020839</td>\n", + " <td>0.011034</td>\n", + " <td>0.002752</td>\n", + " <td>0.054356</td>\n", + " <td>0.020839</td>\n", + " <td>0.011034</td>\n", + " <td>0.002752</td>\n", + " <td>0.054356</td>\n", + " <td>0.000014</td>\n", + " <td>0.000008</td>\n", + " <td>...</td>\n", + " <td>27</td>\n", + " <td>175</td>\n", + " <td>80</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>0.034449</td>\n", + " <td>0.003185</td>\n", + " <td>0.013761</td>\n", + " <td>0.040595</td>\n", + " <td>0.034449</td>\n", + " <td>0.003185</td>\n", + " <td>0.013761</td>\n", + " <td>0.040595</td>\n", + " <td>0.000024</td>\n", + " <td>0.000002</td>\n", + " <td>...</td>\n", + " <td>27</td>\n", + " <td>175</td>\n", + " <td>80</td>\n", + " 
<td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>4 rows × 58 columns</p>\n", + "</div>" + ], + "text/plain": [ + " net_acc_mean net_acc_std net_acc_min net_acc_max ACC_x_mean ACC_x_std \\\n", + "0 0.029937 0.009942 0.000000 0.087383 0.029510 0.011145 \n", + "1 0.021986 0.015845 0.000000 0.071558 0.017352 0.020817 \n", + "2 0.020839 0.011034 0.002752 0.054356 0.020839 0.011034 \n", + "3 0.034449 0.003185 0.013761 0.040595 0.034449 0.003185 \n", + "\n", + " ACC_x_min ACC_x_max ACC_y_mean ACC_y_std ... age height weight \\\n", + "0 -0.024082 0.087383 0.000020 0.000008 ... 27 175 80 \n", + "1 -0.037843 0.071558 0.000012 0.000014 ... 27 175 80 \n", + "2 0.002752 0.054356 0.000014 0.000008 ... 27 175 80 \n", + "3 0.013761 0.040595 0.000024 0.000002 ... 27 175 80 \n", + "\n", + " gender_ female gender_ male coffee_today_YES sport_today_YES smoker_NO \\\n", + "0 0 1 0 0 1 \n", + "1 0 1 0 0 1 \n", + "2 0 1 0 0 1 \n", + "3 0 1 0 0 1 \n", + "\n", + " smoker_YES feel_ill_today_YES \n", + "0 0 0 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 0 0 \n", + "\n", + "[4 rows x 58 columns]" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('data/merged.csv', index_col=0)\n", + "df.head(4)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Index(['net_acc_mean', 'net_acc_std', 'net_acc_min', 'net_acc_max',\n", + " 'ACC_x_mean', 'ACC_x_std', 'ACC_x_min', 'ACC_x_max', 'ACC_y_mean',\n", + " 'ACC_y_std', 'ACC_y_min', 'ACC_y_max', 'ACC_z_mean', 'ACC_z_std',\n", + " 'ACC_z_min', 'ACC_z_max', 'BVP_mean', 'BVP_std', 'BVP_min', 'BVP_max',\n", + " 'EDA_mean', 'EDA_std', 'EDA_min', 'EDA_max', 'EDA_phasic_mean',\n", + " 'EDA_phasic_std', 'EDA_phasic_min', 'EDA_phasic_max', 'EDA_smna_mean',\n", + " 'EDA_smna_std', 'EDA_smna_min', 'EDA_smna_max', 'EDA_tonic_mean',\n", + " 'EDA_tonic_std', 'EDA_tonic_min', 'EDA_tonic_max', 'Resp_mean',\n", + " 'Resp_std', 'Resp_min', 'Resp_max', 'TEMP_mean', 'TEMP_std', 'TEMP_min',\n", + " 'TEMP_max', 'BVP_peak_freq', 'TEMP_slope', 'subject', 'label', 'age',\n", + " 'height', 'weight', 'gender_ female', 'gender_ male',\n", + " 'coffee_today_YES', 'sport_today_YES', 'smoker_NO', 'smoker_YES',\n", + " 'feel_ill_today_YES'],\n", + " dtype='object')" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.columns" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1, 2])" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(df['label'])" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "labels = {\n", + " 0: \"Amused\",\n", + " 1: \"Neutral\",\n", + " 2: \"Stressed\"\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Feature Selection" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "nbconvert_remove" + ] + }, + "outputs": [], + "source": [ + "# nbconvert_remove\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "plt.figure(figsize=(2,100))\n", + "cor = df.corr()\n", + "n_targets = len(df.columns)\n", + "cor_target = 
cor['label'].values.reshape(n_targets, 1)\n",
+    "cor_features = cor['label'].keys()\n",
+    "ax = sns.heatmap(cor_target, annot=True, cmap=plt.cm.Accent_r)\n",
+    "ax.set_yticklabels(cor_features)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 51,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "((1178, 15), (1178,))"
+      ]
+     },
+     "execution_count": 51,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "selected_feats = [\n",
+    "    'BVP_mean', 'BVP_std', 'EDA_phasic_mean', 'EDA_phasic_min', 'EDA_smna_min',\n",
+    "    'EDA_tonic_mean', 'Resp_mean', 'Resp_std', 'TEMP_mean', 'TEMP_std', 'TEMP_slope',\n",
+    "    'BVP_peak_freq', 'age', 'height', 'weight'\n",
+    "]\n",
+    "\n",
+    "X = df[selected_feats]\n",
+    "y = df['label']\n",
+    "\n",
+    "X.shape, y.shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## ML Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 52,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.model_selection import train_test_split\n",
+    "from sklearn.ensemble import RandomForestClassifier"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "((1060, 15), (118, 15))"
+      ]
+     },
+     "execution_count": 53,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Split features and labels in a single call so the rows stay aligned\n",
+    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=0)\n",
+    "\n",
+    "X_train.shape, X_test.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = RandomForestClassifier()\n",
+    "model.fit(X_train, y_train)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 65,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def accuracy(predicted, actual):\n",
+    "    n = 0\n",
+    "    for p, a in zip(predicted, actual):\n",
+    "        if p == a:\n",
+    "            n += 1\n",
+    "    return n / len(predicted) * 100"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def predict(arr):\n",
+    "    # Reshape a single sample to the (1, n_features) shape the classifier expects\n",
+    "    arr = np.array(arr)\n",
+    "\n",
+    "    result = model.predict(arr.reshape(1, -1)).flatten()\n",
+    "    # _prob = model.predict_proba(arr.reshape(1, -1)).flatten()\n",
+    "    return result"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 68,
+   "metadata": {
+    "tags": [
+     "nbconvert_remove"
+    ]
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "95.76271186440678"
+      ]
+     },
+     "execution_count": 68,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "predicted = []\n",
+    "for data in X_test.values:\n",
+    "    predicted.append(predict(data))\n",
+    "\n",
+    "accuracy(predicted, y_test.values)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Saving the trained model to a pickle file for later use by the prediction API"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 69,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pickle\n",
+    "\n",
+    "filename = 'trained_model.sav'\n",
+    "with open(filename, 'wb') as f:\n",
+    "    pickle.dump(model, f)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.10.7 ('venv': venv)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": 
"python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.7" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "b56aae866cfdd3dd1993badfb61811822ff858e2a83b734b90ea6aa544e22f54" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}