a b/model/model.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "markdown",
5
   "metadata": {},
6
   "source": [
7
    "# Stress Detection Using ML"
8
   ]
9
  },
10
  {
11
   "cell_type": "code",
12
   "execution_count": 45,
13
   "metadata": {},
14
   "outputs": [],
15
   "source": [
16
    "import numpy as np\n",
17
    "import pandas as pd"
18
   ]
19
  },
20
  {
21
   "cell_type": "markdown",
22
   "metadata": {},
23
   "source": [
24
    "### Reading the Dataset"
25
   ]
26
  },
27
  {
28
   "cell_type": "code",
29
   "execution_count": 46,
30
   "metadata": {},
31
   "outputs": [
32
    {
33
     "data": {
34
      "text/html": [
35
       "<div>\n",
36
       "<style scoped>\n",
37
       "    .dataframe tbody tr th:only-of-type {\n",
38
       "        vertical-align: middle;\n",
39
       "    }\n",
40
       "\n",
41
       "    .dataframe tbody tr th {\n",
42
       "        vertical-align: top;\n",
43
       "    }\n",
44
       "\n",
45
       "    .dataframe thead th {\n",
46
       "        text-align: right;\n",
47
       "    }\n",
48
       "</style>\n",
49
       "<table border=\"1\" class=\"dataframe\">\n",
50
       "  <thead>\n",
51
       "    <tr style=\"text-align: right;\">\n",
52
       "      <th></th>\n",
53
       "      <th>net_acc_mean</th>\n",
54
       "      <th>net_acc_std</th>\n",
55
       "      <th>net_acc_min</th>\n",
56
       "      <th>net_acc_max</th>\n",
57
       "      <th>ACC_x_mean</th>\n",
58
       "      <th>ACC_x_std</th>\n",
59
       "      <th>ACC_x_min</th>\n",
60
       "      <th>ACC_x_max</th>\n",
61
       "      <th>ACC_y_mean</th>\n",
62
       "      <th>ACC_y_std</th>\n",
63
       "      <th>...</th>\n",
64
       "      <th>age</th>\n",
65
       "      <th>height</th>\n",
66
       "      <th>weight</th>\n",
67
       "      <th>gender_ female</th>\n",
68
       "      <th>gender_ male</th>\n",
69
       "      <th>coffee_today_YES</th>\n",
70
       "      <th>sport_today_YES</th>\n",
71
       "      <th>smoker_NO</th>\n",
72
       "      <th>smoker_YES</th>\n",
73
       "      <th>feel_ill_today_YES</th>\n",
74
       "    </tr>\n",
75
       "  </thead>\n",
76
       "  <tbody>\n",
77
       "    <tr>\n",
78
       "      <th>0</th>\n",
79
       "      <td>0.029937</td>\n",
80
       "      <td>0.009942</td>\n",
81
       "      <td>0.000000</td>\n",
82
       "      <td>0.087383</td>\n",
83
       "      <td>0.029510</td>\n",
84
       "      <td>0.011145</td>\n",
85
       "      <td>-0.024082</td>\n",
86
       "      <td>0.087383</td>\n",
87
       "      <td>0.000020</td>\n",
88
       "      <td>0.000008</td>\n",
89
       "      <td>...</td>\n",
90
       "      <td>27</td>\n",
91
       "      <td>175</td>\n",
92
       "      <td>80</td>\n",
93
       "      <td>0</td>\n",
94
       "      <td>1</td>\n",
95
       "      <td>0</td>\n",
96
       "      <td>0</td>\n",
97
       "      <td>1</td>\n",
98
       "      <td>0</td>\n",
99
       "      <td>0</td>\n",
100
       "    </tr>\n",
101
       "    <tr>\n",
102
       "      <th>1</th>\n",
103
       "      <td>0.021986</td>\n",
104
       "      <td>0.015845</td>\n",
105
       "      <td>0.000000</td>\n",
106
       "      <td>0.071558</td>\n",
107
       "      <td>0.017352</td>\n",
108
       "      <td>0.020817</td>\n",
109
       "      <td>-0.037843</td>\n",
110
       "      <td>0.071558</td>\n",
111
       "      <td>0.000012</td>\n",
112
       "      <td>0.000014</td>\n",
113
       "      <td>...</td>\n",
114
       "      <td>27</td>\n",
115
       "      <td>175</td>\n",
116
       "      <td>80</td>\n",
117
       "      <td>0</td>\n",
118
       "      <td>1</td>\n",
119
       "      <td>0</td>\n",
120
       "      <td>0</td>\n",
121
       "      <td>1</td>\n",
122
       "      <td>0</td>\n",
123
       "      <td>0</td>\n",
124
       "    </tr>\n",
125
       "    <tr>\n",
126
       "      <th>2</th>\n",
127
       "      <td>0.020839</td>\n",
128
       "      <td>0.011034</td>\n",
129
       "      <td>0.002752</td>\n",
130
       "      <td>0.054356</td>\n",
131
       "      <td>0.020839</td>\n",
132
       "      <td>0.011034</td>\n",
133
       "      <td>0.002752</td>\n",
134
       "      <td>0.054356</td>\n",
135
       "      <td>0.000014</td>\n",
136
       "      <td>0.000008</td>\n",
137
       "      <td>...</td>\n",
138
       "      <td>27</td>\n",
139
       "      <td>175</td>\n",
140
       "      <td>80</td>\n",
141
       "      <td>0</td>\n",
142
       "      <td>1</td>\n",
143
       "      <td>0</td>\n",
144
       "      <td>0</td>\n",
145
       "      <td>1</td>\n",
146
       "      <td>0</td>\n",
147
       "      <td>0</td>\n",
148
       "    </tr>\n",
149
       "    <tr>\n",
150
       "      <th>3</th>\n",
151
       "      <td>0.034449</td>\n",
152
       "      <td>0.003185</td>\n",
153
       "      <td>0.013761</td>\n",
154
       "      <td>0.040595</td>\n",
155
       "      <td>0.034449</td>\n",
156
       "      <td>0.003185</td>\n",
157
       "      <td>0.013761</td>\n",
158
       "      <td>0.040595</td>\n",
159
       "      <td>0.000024</td>\n",
160
       "      <td>0.000002</td>\n",
161
       "      <td>...</td>\n",
162
       "      <td>27</td>\n",
163
       "      <td>175</td>\n",
164
       "      <td>80</td>\n",
165
       "      <td>0</td>\n",
166
       "      <td>1</td>\n",
167
       "      <td>0</td>\n",
168
       "      <td>0</td>\n",
169
       "      <td>1</td>\n",
170
       "      <td>0</td>\n",
171
       "      <td>0</td>\n",
172
       "    </tr>\n",
173
       "  </tbody>\n",
174
       "</table>\n",
175
       "<p>4 rows × 58 columns</p>\n",
176
       "</div>"
177
      ],
178
      "text/plain": [
179
       "   net_acc_mean  net_acc_std  net_acc_min  net_acc_max  ACC_x_mean  ACC_x_std  \\\n",
180
       "0      0.029937     0.009942     0.000000     0.087383    0.029510   0.011145   \n",
181
       "1      0.021986     0.015845     0.000000     0.071558    0.017352   0.020817   \n",
182
       "2      0.020839     0.011034     0.002752     0.054356    0.020839   0.011034   \n",
183
       "3      0.034449     0.003185     0.013761     0.040595    0.034449   0.003185   \n",
184
       "\n",
185
       "   ACC_x_min  ACC_x_max  ACC_y_mean  ACC_y_std  ...  age  height  weight  \\\n",
186
       "0  -0.024082   0.087383    0.000020   0.000008  ...   27     175      80   \n",
187
       "1  -0.037843   0.071558    0.000012   0.000014  ...   27     175      80   \n",
188
       "2   0.002752   0.054356    0.000014   0.000008  ...   27     175      80   \n",
189
       "3   0.013761   0.040595    0.000024   0.000002  ...   27     175      80   \n",
190
       "\n",
191
       "   gender_ female  gender_ male  coffee_today_YES  sport_today_YES  smoker_NO  \\\n",
192
       "0               0             1                 0                0          1   \n",
193
       "1               0             1                 0                0          1   \n",
194
       "2               0             1                 0                0          1   \n",
195
       "3               0             1                 0                0          1   \n",
196
       "\n",
197
       "   smoker_YES  feel_ill_today_YES  \n",
198
       "0           0                   0  \n",
199
       "1           0                   0  \n",
200
       "2           0                   0  \n",
201
       "3           0                   0  \n",
202
       "\n",
203
       "[4 rows x 58 columns]"
204
      ]
205
     },
206
     "execution_count": 46,
207
     "metadata": {},
208
     "output_type": "execute_result"
209
    }
210
   ],
211
   "source": [
212
    "df = pd.read_csv('data/merged.csv', index_col=0)\n",
213
    "df.head(4)"
214
   ]
215
  },
216
  {
217
   "cell_type": "code",
218
   "execution_count": 47,
219
   "metadata": {},
220
   "outputs": [
221
    {
222
     "data": {
223
      "text/plain": [
224
       "Index(['net_acc_mean', 'net_acc_std', 'net_acc_min', 'net_acc_max',\n",
225
       "       'ACC_x_mean', 'ACC_x_std', 'ACC_x_min', 'ACC_x_max', 'ACC_y_mean',\n",
226
       "       'ACC_y_std', 'ACC_y_min', 'ACC_y_max', 'ACC_z_mean', 'ACC_z_std',\n",
227
       "       'ACC_z_min', 'ACC_z_max', 'BVP_mean', 'BVP_std', 'BVP_min', 'BVP_max',\n",
228
       "       'EDA_mean', 'EDA_std', 'EDA_min', 'EDA_max', 'EDA_phasic_mean',\n",
229
       "       'EDA_phasic_std', 'EDA_phasic_min', 'EDA_phasic_max', 'EDA_smna_mean',\n",
230
       "       'EDA_smna_std', 'EDA_smna_min', 'EDA_smna_max', 'EDA_tonic_mean',\n",
231
       "       'EDA_tonic_std', 'EDA_tonic_min', 'EDA_tonic_max', 'Resp_mean',\n",
232
       "       'Resp_std', 'Resp_min', 'Resp_max', 'TEMP_mean', 'TEMP_std', 'TEMP_min',\n",
233
       "       'TEMP_max', 'BVP_peak_freq', 'TEMP_slope', 'subject', 'label', 'age',\n",
234
       "       'height', 'weight', 'gender_ female', 'gender_ male',\n",
235
       "       'coffee_today_YES', 'sport_today_YES', 'smoker_NO', 'smoker_YES',\n",
236
       "       'feel_ill_today_YES'],\n",
237
       "      dtype='object')"
238
      ]
239
     },
240
     "execution_count": 47,
241
     "metadata": {},
242
     "output_type": "execute_result"
243
    }
244
   ],
245
   "source": [
246
    "df.columns"
247
   ]
248
  },
249
  {
250
   "cell_type": "code",
251
   "execution_count": 48,
252
   "metadata": {},
253
   "outputs": [
254
    {
255
     "data": {
256
      "text/plain": [
257
       "array([0, 1, 2])"
258
      ]
259
     },
260
     "execution_count": 48,
261
     "metadata": {},
262
     "output_type": "execute_result"
263
    }
264
   ],
265
   "source": [
266
    "np.unique(df['label'])"
267
   ]
268
  },
269
  {
270
   "cell_type": "code",
271
   "execution_count": 49,
272
   "metadata": {},
273
   "outputs": [],
274
   "source": [
275
    "labels = {\n",
276
    "    0: \"Amused\",\n",
277
    "    1: \"Neutral\",\n",
278
    "    2: \"Stressed\"\n",
279
    "}"
280
   ]
281
  },
282
  {
283
   "cell_type": "markdown",
284
   "metadata": {},
285
   "source": [
286
    "### Feature Selection"
287
   ]
288
  },
289
  {
290
   "cell_type": "code",
291
   "execution_count": null,
292
   "metadata": {
293
    "tags": [
294
     "nbconvert_remove"
295
    ]
296
   },
297
   "outputs": [],
298
   "source": [
299
    "# nbconvert_remove\n",
300
    "\n",
301
    "import matplotlib.pyplot as plt\n",
302
    "import seaborn as sns\n",
303
    "\n",
304
    "plt.figure(figsize=(2,100))\n",
305
    "cor = df.corr()\n",
306
    "n_targets = len(df.columns)\n",
307
    "cor_target = cor['label'].values.reshape(n_targets, 1)\n",
308
    "cor_features = cor['label'].keys()\n",
309
    "ax = sns.heatmap(cor_target, annot=True, cmap=plt.cm.Accent_r)\n",
310
    "ax.set_yticklabels(cor_features)\n",
311
    "plt.show()"
312
   ]
313
  },
314
  {
315
   "cell_type": "code",
316
   "execution_count": 51,
317
   "metadata": {},
318
   "outputs": [
319
    {
320
     "data": {
321
      "text/plain": [
322
       "((1178, 15), (1178,))"
323
      ]
324
     },
325
     "execution_count": 51,
326
     "metadata": {},
327
     "output_type": "execute_result"
328
    }
329
   ],
330
   "source": [
    "# Features kept for the model: physiological signals plus subject demographics.\n",
    "selected_feats = [\n",
    "    'BVP_mean', 'BVP_std', 'EDA_phasic_mean', 'EDA_phasic_min', 'EDA_smna_min',\n",
    "    'EDA_tonic_mean', 'Resp_mean', 'Resp_std', 'TEMP_mean', 'TEMP_std', 'TEMP_slope',\n",
    "    'BVP_peak_freq', 'age', 'height', 'weight'\n",
    "]\n",
    "\n",
    "X = df[selected_feats]\n",
    "y = df['label']\n",
    "\n",
    "X.shape, y.shape"
   ]
342
  },
343
  {
344
   "cell_type": "markdown",
345
   "metadata": {},
346
   "source": [
347
    "## ML Model"
348
   ]
349
  },
350
  {
351
   "cell_type": "code",
352
   "execution_count": 52,
353
   "metadata": {},
354
   "outputs": [],
355
   "source": [
356
    "from sklearn.model_selection import train_test_split\n",
357
    "from sklearn.ensemble import RandomForestClassifier"
358
   ]
359
  },
360
  {
361
   "cell_type": "code",
362
   "execution_count": 53,
363
   "metadata": {},
364
   "outputs": [
365
    {
366
     "data": {
367
      "text/plain": [
368
       "((1060, 15), (118, 15))"
369
      ]
370
     },
371
     "execution_count": 53,
372
     "metadata": {},
373
     "output_type": "execute_result"
374
    }
375
   ],
376
   "source": [
    "# A single split call keeps X and y rows aligned; the original separate\n",
    "# calls only stayed aligned by the accident of sharing random_state=0.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=0.1, random_state=0\n",
    ")\n",
    "\n",
    "X_train.shape, X_test.shape"
   ]
382
  },
383
  {
384
   "cell_type": "code",
385
   "execution_count": null,
386
   "metadata": {},
387
   "outputs": [],
388
   "source": [
389
    "model = RandomForestClassifier()\n",
390
    "model.fit(X_train,y_train)"
391
   ]
392
  },
393
  {
394
   "cell_type": "code",
395
   "execution_count": 65,
396
   "metadata": {},
397
   "outputs": [],
398
   "source": [
    "def accuracy(predicted, actual):\n",
    "    \"\"\"Return the percentage of predictions matching the actual labels.\n",
    "\n",
    "    Returns 0.0 for empty input instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    if len(predicted) == 0:\n",
    "        return 0.0\n",
    "    n = sum(1 for p, a in zip(predicted, actual) if p == a)\n",
    "    return n / len(predicted) * 100"
   ]
406
  },
407
  {
408
   "cell_type": "code",
409
   "execution_count": 66,
410
   "metadata": {},
411
   "outputs": [],
412
   "source": [
    "def predict(arr):\n",
    "    \"\"\"Predict the stress label for a single feature vector.\n",
    "\n",
    "    Reads the module-level `model`; `global` is unnecessary for reads,\n",
    "    only for assignment, so it has been dropped.\n",
    "    \"\"\"\n",
    "    arr = np.asarray(arr)\n",
    "    return model.predict(arr.reshape(1, -1)).flatten()"
   ]
421
  },
422
  {
423
   "cell_type": "code",
424
   "execution_count": 68,
425
   "metadata": {
426
    "tags": [
427
     "nbconvert_remove"
428
    ]
429
   },
430
   "outputs": [
431
    {
432
     "data": {
433
      "text/plain": [
434
       "95.76271186440678"
435
      ]
436
     },
437
     "execution_count": 68,
438
     "metadata": {},
439
     "output_type": "execute_result"
440
    }
441
   ],
442
   "source": [
    "# Predict each test row through the API-style helper, then score.\n",
    "# (The original had a stray bare `predicted` expression mid-cell that\n",
    "# displayed nothing — only a cell's last expression is rendered.)\n",
    "predicted = [predict(row) for row in X_test.values]\n",
    "accuracy(predicted, y_test.values)"
   ]
450
  },
451
  {
452
   "cell_type": "markdown",
453
   "metadata": {},
454
   "source": [
455
    "### Saving the trained model in a pickle file to be later used by the API function to predict"
456
   ]
457
  },
458
  {
459
   "cell_type": "code",
460
   "execution_count": 69,
461
   "metadata": {},
462
   "outputs": [],
463
   "source": [
    "import pickle\n",
    "\n",
    "filename = 'trained_model.sav'\n",
    "# Context manager ensures the handle is flushed and closed; the original\n",
    "# `pickle.dump(model, open(filename, 'wb'))` leaked the file object.\n",
    "with open(filename, 'wb') as f:\n",
    "    pickle.dump(model, f)"
   ]
469
  }
470
 ],
471
 "metadata": {
472
  "kernelspec": {
473
   "display_name": "Python 3.10.7 ('venv': venv)",
474
   "language": "python",
475
   "name": "python3"
476
  },
477
  "language_info": {
478
   "codemirror_mode": {
479
    "name": "ipython",
480
    "version": 3
481
   },
482
   "file_extension": ".py",
483
   "mimetype": "text/x-python",
484
   "name": "python",
485
   "nbconvert_exporter": "python",
486
   "pygments_lexer": "ipython3",
487
   "version": "3.10.7"
488
  },
489
  "orig_nbformat": 4,
490
  "vscode": {
491
   "interpreter": {
492
    "hash": "b56aae866cfdd3dd1993badfb61811822ff858e2a83b734b90ea6aa544e22f54"
493
   }
494
  }
495
 },
496
 "nbformat": 4,
497
 "nbformat_minor": 2
498
}