{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Stress Detection Using ML" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reading the Dataset" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
net_acc_meannet_acc_stdnet_acc_minnet_acc_maxACC_x_meanACC_x_stdACC_x_minACC_x_maxACC_y_meanACC_y_std...ageheightweightgender_ femalegender_ malecoffee_today_YESsport_today_YESsmoker_NOsmoker_YESfeel_ill_today_YES
00.0299370.0099420.0000000.0873830.0295100.011145-0.0240820.0873830.0000200.000008...27175800100100
10.0219860.0158450.0000000.0715580.0173520.020817-0.0378430.0715580.0000120.000014...27175800100100
20.0208390.0110340.0027520.0543560.0208390.0110340.0027520.0543560.0000140.000008...27175800100100
30.0344490.0031850.0137610.0405950.0344490.0031850.0137610.0405950.0000240.000002...27175800100100
\n", "

4 rows × 58 columns

\n", "
" ], "text/plain": [ " net_acc_mean net_acc_std net_acc_min net_acc_max ACC_x_mean ACC_x_std \\\n", "0 0.029937 0.009942 0.000000 0.087383 0.029510 0.011145 \n", "1 0.021986 0.015845 0.000000 0.071558 0.017352 0.020817 \n", "2 0.020839 0.011034 0.002752 0.054356 0.020839 0.011034 \n", "3 0.034449 0.003185 0.013761 0.040595 0.034449 0.003185 \n", "\n", " ACC_x_min ACC_x_max ACC_y_mean ACC_y_std ... age height weight \\\n", "0 -0.024082 0.087383 0.000020 0.000008 ... 27 175 80 \n", "1 -0.037843 0.071558 0.000012 0.000014 ... 27 175 80 \n", "2 0.002752 0.054356 0.000014 0.000008 ... 27 175 80 \n", "3 0.013761 0.040595 0.000024 0.000002 ... 27 175 80 \n", "\n", " gender_ female gender_ male coffee_today_YES sport_today_YES smoker_NO \\\n", "0 0 1 0 0 1 \n", "1 0 1 0 0 1 \n", "2 0 1 0 0 1 \n", "3 0 1 0 0 1 \n", "\n", " smoker_YES feel_ill_today_YES \n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "\n", "[4 rows x 58 columns]" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('data/merged.csv', index_col=0)\n", "df.head(4)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['net_acc_mean', 'net_acc_std', 'net_acc_min', 'net_acc_max',\n", " 'ACC_x_mean', 'ACC_x_std', 'ACC_x_min', 'ACC_x_max', 'ACC_y_mean',\n", " 'ACC_y_std', 'ACC_y_min', 'ACC_y_max', 'ACC_z_mean', 'ACC_z_std',\n", " 'ACC_z_min', 'ACC_z_max', 'BVP_mean', 'BVP_std', 'BVP_min', 'BVP_max',\n", " 'EDA_mean', 'EDA_std', 'EDA_min', 'EDA_max', 'EDA_phasic_mean',\n", " 'EDA_phasic_std', 'EDA_phasic_min', 'EDA_phasic_max', 'EDA_smna_mean',\n", " 'EDA_smna_std', 'EDA_smna_min', 'EDA_smna_max', 'EDA_tonic_mean',\n", " 'EDA_tonic_std', 'EDA_tonic_min', 'EDA_tonic_max', 'Resp_mean',\n", " 'Resp_std', 'Resp_min', 'Resp_max', 'TEMP_mean', 'TEMP_std', 'TEMP_min',\n", " 'TEMP_max', 'BVP_peak_freq', 'TEMP_slope', 'subject', 'label', 'age',\n", " 'height', 'weight', 'gender_ female', 'gender_ 
male',\n", " 'coffee_today_YES', 'sport_today_YES', 'smoker_NO', 'smoker_YES',\n", " 'feel_ill_today_YES'],\n", " dtype='object')" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.columns" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 1, 2])" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.unique(df['label'])" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "labels = {\n", " 0: \"Amused\",\n", " 1: \"Neutral\",\n", " 2: \"Stressed\"\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Feature Selection" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "nbconvert_remove" ] }, "outputs": [], "source": [ "# nbconvert_remove\n", "\n", "# Plotting imports stay in this cell (not the top import cell) because the cell\n", "# is tagged `nbconvert_remove`; presumably it is stripped on export, so the\n", "# exported code does not need matplotlib/seaborn.\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "# Correlation of every column with the target, minus the trivial label-vs-label\n", "# entry. `label` is a 3-class code (see `labels`), so a linear correlation is\n", "# only a rough screen for candidate features.\n", "label_corr = df.corr()['label'].drop('label').to_frame()\n", "\n", "# Plotting a labelled DataFrame with yticklabels=True keeps every row paired\n", "# with its column name, instead of relabelling seaborn's (possibly thinned)\n", "# ticks by hand. A diverging colormap centred on 0 suits values in [-1, 1].\n", "fig, ax = plt.subplots(figsize=(3, 0.3 * len(label_corr)))\n", "sns.heatmap(label_corr, annot=True, fmt='.2f', cmap='coolwarm', center=0,\n", "            vmin=-1, vmax=1, yticklabels=True, ax=ax)\n", "ax.set_title('Correlation with label')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((1178, 15), (1178,))" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "selected_feats = [\n", " 'BVP_mean', 'BVP_std', 'EDA_phasic_mean', 'EDA_phasic_min', 'EDA_smna_min', \n", " 'EDA_tonic_mean', 'Resp_mean', 'Resp_std', 'TEMP_mean', 'TEMP_std', 'TEMP_slope',\n", " 'BVP_peak_freq', 'age', 'height', 'weight'\n", " ]\n", "\n", "X = df[selected_feats]\n", "y = df['label']\n", "\n", "X.shape, y.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ML Model" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], 
"source": [ "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import RandomForestClassifier" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((1060, 15), (118, 15))" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Seed shared by the split and the model so results are reproducible.\n", "RANDOM_STATE = 0\n", "\n", "# Split X and y in one call so their rows stay aligned by construction.\n", "# NOTE(review): this is a random row split. `df` has a `subject` column and\n", "# age/height/weight look constant per subject, so windows from the same person\n", "# likely land in both train and test, making the test accuracy optimistic.\n", "# Consider GroupShuffleSplit with groups=df['subject'].\n", "X_train, X_test, y_train, y_test = train_test_split(\n", "    X, y, test_size=0.1, random_state=RANDOM_STATE\n", ")\n", "\n", "X_train.shape, X_test.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fixed seed so the fitted forest (and the accuracy below) can be reproduced.\n", "model = RandomForestClassifier(random_state=RANDOM_STATE)\n", "model.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "def accuracy(predicted, actual):\n", "    \"\"\"Return the percentage (0-100) of positions where `predicted` matches `actual`.\n", "\n", "    Elements may be scalars or length-1 arrays (as returned by `predict`);\n", "    both inputs are flattened before comparing.\n", "\n", "    Raises:\n", "        ValueError: if the inputs differ in length or are empty.\n", "    \"\"\"\n", "    predicted = np.asarray(predicted).ravel()\n", "    actual = np.asarray(actual).ravel()\n", "    if predicted.size != actual.size:\n", "        raise ValueError(f'length mismatch: {predicted.size} predictions vs {actual.size} labels')\n", "    if predicted.size == 0:\n", "        raise ValueError('cannot compute accuracy of an empty prediction set')\n", "    return float(np.mean(predicted == actual)) * 100" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "def predict(arr):\n", "    \"\"\"Classify one sample with the trained `model`.\n", "\n", "    `arr` is a single feature vector, presumably ordered like `selected_feats`\n", "    (the columns the model was fitted on). Returns a 1-element array holding\n", "    the predicted label code (see `labels`).\n", "    \"\"\"\n", "    sample = np.asarray(arr).reshape(1, -1)\n", "    return model.predict(sample).flatten()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "nbconvert_remove" ] }, "outputs": [], "source": [ "# One vectorised call over the whole test set instead of a per-row loop\n", "# through `predict`.\n", "predicted = model.predict(X_test)\n", "\n", "accuracy(predicted, y_test.values)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Saving the trained model in a pickle file to be later used by the API function to predict" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "filename = 'trained_model.sav'\n", "# The context manager closes (and flushes) the file even if dump() raises.\n", "# Only unpickle this file from a trusted location: loading a pickle can run code.\n", "with open(filename, 'wb') as model_file:\n", "    pickle.dump(model, model_file)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10.7 ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.7" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "b56aae866cfdd3dd1993badfb61811822ff858e2a83b734b90ea6aa544e22f54" } } }, "nbformat": 4, "nbformat_minor": 2 }