a b/SWELL-KW/SWELL-KW_LSTM.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "markdown",
5
   "metadata": {},
6
   "source": [
7
    "# SWELL-KW LSTM"
8
   ]
9
  },
10
  {
11
   "cell_type": "markdown",
12
   "metadata": {},
13
   "source": [
14
    "Adapted from Microsoft's EdgeML notebooks by Dennis et al., available at [microsoft/EdgeML](https://github.com/microsoft/EdgeML)."
15
   ]
16
  },
17
  {
18
   "cell_type": "markdown",
19
   "metadata": {},
20
   "source": [
21
    "## Imports"
22
   ]
23
  },
24
  {
25
   "cell_type": "code",
26
   "execution_count": 1,
27
   "metadata": {
28
    "ExecuteTime": {
29
     "end_time": "2019-07-15T18:30:17.522073Z",
30
     "start_time": "2019-07-15T18:30:17.217355Z"
31
    },
32
    "scrolled": true
33
   },
34
   "outputs": [],
35
   "source": [
36
    "import pandas as pd\n",
37
    "import numpy as np\n",
38
    "from tabulate import tabulate\n",
39
    "import os\n",
40
    "import datetime as datetime\n",
41
    "import pickle as pkl\n",
42
    "import pathlib"
43
   ]
44
  },
45
  {
46
   "cell_type": "markdown",
47
   "metadata": {},
48
   "source": [
49
    "## Load DataFrame from CSV"
50
   ]
51
  },
52
  {
53
   "cell_type": "code",
54
   "execution_count": 224,
55
   "metadata": {
56
    "ExecuteTime": {
57
     "end_time": "2019-07-15T18:30:28.212447Z",
58
     "start_time": "2019-07-15T18:30:17.889941Z"
59
    }
60
   },
61
   "outputs": [],
62
   "source": [
63
    "df = pd.read_csv('/home/sf/data/SWELL-KW/EDA_New.csv', index_col=0)"
64
   ]
65
  },
66
  {
67
   "cell_type": "markdown",
68
   "metadata": {},
69
   "source": [
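70
    "## Preprocessing\n",
    "\n",
    "Binarize the labels: both stress conditions ('time pressure' and 'interruption') become class 1, and 'no stress' becomes class 0."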
71
   ]
72
  },
73
  {
74
   "cell_type": "code",
75
   "execution_count": 226,
76
   "metadata": {},
77
   "outputs": [],
78
   "source": [
79
    "df['condition'].replace('time pressure', 'interruption', inplace=True)"
80
   ]
81
  },
82
  {
83
   "cell_type": "code",
84
   "execution_count": 227,
85
   "metadata": {},
86
   "outputs": [],
87
   "source": [
88
    "df['condition'].replace('interruption', 1, inplace=True)\n",
89
    "df['condition'].replace('no stress', 0, inplace=True)"
90
   ]
91
  },
92
  {
93
   "cell_type": "markdown",
94
   "metadata": {},
95
   "source": [
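96
    "## Split Features and Ground Truth\n",
    "\n",
    "Separate the `condition` label column from the feature columns, then group the labels into rows of 20 to match the bags built below."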
97
   ]
98
  },
99
  {
100
   "cell_type": "code",
101
   "execution_count": 188,
102
   "metadata": {},
103
   "outputs": [
104
    {
105
     "data": {
106
      "text/plain": [
107
       "0.027338218248412812"
108
      ]
109
     },
110
     "execution_count": 188,
111
     "metadata": {},
112
     "output_type": "execute_result"
113
    }
114
   ],
115
   "source": [
116
    "df[df['condition']==1]['RANGE'].mean()"
117
   ]
118
  },
119
  {
120
   "cell_type": "code",
121
   "execution_count": 229,
122
   "metadata": {
123
    "ExecuteTime": {
124
     "end_time": "2019-07-15T18:30:37.760226Z",
125
     "start_time": "2019-07-15T18:30:37.597908Z"
126
    }
127
   },
128
   "outputs": [],
129
   "source": [
130
    "filtered_train = df.drop(['condition'], axis=1)\n",
131
    "filtered_target = df['condition']"
132
   ]
133
  },
134
  {
135
   "cell_type": "code",
136
   "execution_count": 231,
137
   "metadata": {
138
    "ExecuteTime": {
139
     "end_time": "2019-07-15T18:32:27.581110Z",
140
     "start_time": "2019-07-15T18:32:27.576162Z"
141
    }
142
   },
143
   "outputs": [
144
    {
145
     "name": "stdout",
146
     "output_type": "stream",
147
     "text": [
148
      "(104040,)\n",
149
      "(104040, 22)\n"
150
     ]
151
    }
152
   ],
153
   "source": [
154
    "print(filtered_target.shape)\n",
155
    "print(filtered_train.shape)"
156
   ]
157
  },
158
  {
159
   "cell_type": "code",
160
   "execution_count": 232,
161
   "metadata": {},
162
   "outputs": [
163
    {
164
     "data": {
165
      "text/plain": [
166
       "0    52695\n",
167
       "1    51345\n",
168
       "Name: condition, dtype: int64"
169
      ]
170
     },
171
     "execution_count": 232,
172
     "metadata": {},
173
     "output_type": "execute_result"
174
    }
175
   ],
176
   "source": [
177
    "pd.value_counts(filtered_target)"
178
   ]
179
  },
180
  {
181
   "cell_type": "code",
182
   "execution_count": 233,
183
   "metadata": {
184
    "ExecuteTime": {
185
     "end_time": "2019-06-19T05:24:41.449331Z",
186
     "start_time": "2019-06-19T05:24:41.445741Z"
187
    },
188
    "scrolled": true
189
   },
190
   "outputs": [],
191
   "source": [
192
    "y = filtered_target.values.reshape(-1, 20)"
193
   ]
194
  },
195
  {
196
   "cell_type": "markdown",
197
   "metadata": {},
198
   "source": [
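199
    "## Convert to 3D - (Bags, Timesteps, Features)\n",
    "\n",
    "Every 20 consecutive rows form one bag. This assumes the CSV rows are in temporal order and that the row count is a multiple of 20 (104040 = 5202 x 20)."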
200
   ]
201
  },
202
  {
203
   "cell_type": "code",
204
   "execution_count": 234,
205
   "metadata": {},
206
   "outputs": [
207
    {
208
     "data": {
209
      "text/plain": [
210
       "22"
211
      ]
212
     },
213
     "execution_count": 234,
214
     "metadata": {},
215
     "output_type": "execute_result"
216
    }
217
   ],
218
   "source": [
219
    "len(filtered_train.columns)"
220
   ]
221
  },
222
  {
223
   "cell_type": "code",
224
   "execution_count": 235,
225
   "metadata": {
226
    "ExecuteTime": {
227
     "end_time": "2019-06-19T05:24:43.963421Z",
228
     "start_time": "2019-06-19T05:24:43.944207Z"
229
    }
230
   },
231
   "outputs": [
232
    {
233
     "name": "stdout",
234
     "output_type": "stream",
235
     "text": [
236
      "(104040, 22)\n",
237
      "(5202, 20, 22)\n"
238
     ]
239
    }
240
   ],
241
   "source": [
242
    "x = filtered_train.values\n",
243
    "print(x.shape)\n",
244
    "x = x.reshape(int(len(x) / 20), 20, 22)\n",
245
    "print(x.shape)"
246
   ]
247
  },
248
  {
249
   "cell_type": "markdown",
250
   "metadata": {},
251
   "source": [
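252
    "## Filter Overlapping Bags\n",
    "\n",
    "Drop every bag whose 20 rows do not all share the same label, i.e. windows that straddle a change of condition."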
253
   ]
254
  },
255
  {
256
   "cell_type": "code",
257
   "execution_count": 236,
258
   "metadata": {
259
    "ExecuteTime": {
260
     "end_time": "2019-06-19T05:25:19.299273Z",
261
     "start_time": "2019-06-19T05:25:18.952971Z"
262
    }
263
   },
264
   "outputs": [
265
    {
266
     "name": "stdout",
267
     "output_type": "stream",
268
     "text": [
269
      "[120, 198, 289, 408, 410, 415, 544, 667, 776, 894, 908, 911, 913, 1019, 1103, 1117, 1147, 1150, 1191, 1305, 1440, 1549, 1643, 1660, 1685, 1805, 1937, 2042, 2175, 2179, 2183, 2184, 2384, 2490, 2566, 2571, 2594, 2597, 2637, 2638, 2761, 2810, 2918, 2990, 3004, 3008, 3015, 3056, 3164, 3236, 3244, 3274, 3277, 3318, 3433, 3546, 3556, 3560, 3562, 3563, 3564, 3676, 3786, 3796, 3800, 3839, 3840, 3841, 3956, 4051, 4081, 4084, 4125, 4126, 4232, 4310, 4411, 4527, 4544, 4546, 4549, 4554, 4559, 4564, 4687, 4789, 4895, 4968, 4985, 4989, 5008, 5131]\n"
270
     ]
271
    }
272
   ],
273
   "source": [
274
    "# filtering bags that overlap with another class\n",
275
    "bags_to_remove = []\n",
276
    "for i in range(len(y)):\n",
277
    "    if len(set(y[i])) > 1:\n",
278
    "        bags_to_remove.append(i)\n",
279
    "print(bags_to_remove)"
280
   ]
281
  },
282
  {
283
   "cell_type": "code",
284
   "execution_count": 237,
285
   "metadata": {
286
    "ExecuteTime": {
287
     "end_time": "2019-06-19T05:25:26.446153Z",
288
     "start_time": "2019-06-19T05:25:26.256245Z"
289
    }
290
   },
291
   "outputs": [],
292
   "source": [
293
    "x = np.delete(x, bags_to_remove, axis=0)\n",
294
    "y = np.delete(y, bags_to_remove, axis=0)"
295
   ]
296
  },
297
  {
298
   "cell_type": "code",
299
   "execution_count": 238,
300
   "metadata": {
301
    "ExecuteTime": {
302
     "end_time": "2019-06-19T05:25:27.260726Z",
303
     "start_time": "2019-06-19T05:25:27.254474Z"
304
    }
305
   },
306
   "outputs": [
307
    {
308
     "data": {
309
      "text/plain": [
310
       "(5110, 20, 22)"
311
      ]
312
     },
313
     "execution_count": 238,
314
     "metadata": {},
315
     "output_type": "execute_result"
316
    }
317
   ],
318
   "source": [
319
    "x.shape "
320
   ]
321
  },
322
  {
323
   "cell_type": "code",
324
   "execution_count": 239,
325
   "metadata": {
326
    "ExecuteTime": {
327
     "end_time": "2019-06-19T05:25:28.101475Z",
328
     "start_time": "2019-06-19T05:25:28.096009Z"
329
    }
330
   },
331
   "outputs": [
332
    {
333
     "data": {
334
      "text/plain": [
335
       "(5110, 20)"
336
      ]
337
     },
338
     "execution_count": 239,
339
     "metadata": {},
340
     "output_type": "execute_result"
341
    }
342
   ],
343
   "source": [
344
    "y.shape"
345
   ]
346
  },
347
  {
348
   "cell_type": "markdown",
349
   "metadata": {},
350
   "source": [
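351
    "## Categorical Representation\n",
    "\n",
    "After filtering, each bag carries a single label, so it is collapsed to one integer class; `one_hot` later expands integer classes into one-hot rows."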
352
   ]
353
  },
354
  {
355
   "cell_type": "code",
356
   "execution_count": 240,
357
   "metadata": {
358
    "ExecuteTime": {
359
     "end_time": "2019-06-19T05:25:50.094089Z",
360
     "start_time": "2019-06-19T05:25:49.746284Z"
361
    }
362
   },
363
   "outputs": [],
364
   "source": [
365
    "one_hot_list = []\n",
366
    "for i in range(len(y)):\n",
367
    "    one_hot_list.append(set(y[i]).pop())"
368
   ]
369
  },
370
  {
371
   "cell_type": "code",
372
   "execution_count": 241,
373
   "metadata": {
374
    "ExecuteTime": {
375
     "end_time": "2019-06-19T05:25:52.203633Z",
376
     "start_time": "2019-06-19T05:25:52.198467Z"
377
    }
378
   },
379
   "outputs": [],
380
   "source": [
381
    "categorical_y_ver = one_hot_list\n",
382
    "categorical_y_ver = np.array(categorical_y_ver)"
383
   ]
384
  },
385
  {
386
   "cell_type": "code",
387
   "execution_count": 242,
388
   "metadata": {
389
    "ExecuteTime": {
390
     "end_time": "2019-06-19T05:25:53.132006Z",
391
     "start_time": "2019-06-19T05:25:53.126314Z"
392
    }
393
   },
394
   "outputs": [
395
    {
396
     "data": {
397
      "text/plain": [
398
       "(5110,)"
399
      ]
400
     },
401
     "execution_count": 242,
402
     "metadata": {},
403
     "output_type": "execute_result"
404
    }
405
   ],
406
   "source": [
407
    "categorical_y_ver.shape"
408
   ]
409
  },
410
  {
411
   "cell_type": "code",
412
   "execution_count": 243,
413
   "metadata": {
414
    "ExecuteTime": {
415
     "end_time": "2019-06-19T05:25:54.495163Z",
416
     "start_time": "2019-06-19T05:25:54.489349Z"
417
    }
418
   },
419
   "outputs": [
420
    {
421
     "data": {
422
      "text/plain": [
423
       "20"
424
      ]
425
     },
426
     "execution_count": 243,
427
     "metadata": {},
428
     "output_type": "execute_result"
429
    }
430
   ],
431
   "source": [
432
    "x.shape[1]"
433
   ]
434
  },
435
  {
436
   "cell_type": "code",
437
   "execution_count": 244,
438
   "metadata": {
439
    "ExecuteTime": {
440
     "end_time": "2019-06-19T05:26:08.021392Z",
441
     "start_time": "2019-06-19T05:26:08.017038Z"
442
    }
443
   },
444
   "outputs": [],
445
   "source": [
446
    "def one_hot(y, numOutput):\n",
447
    "    y = np.reshape(y, [-1])\n",
448
    "    ret = np.zeros([y.shape[0], numOutput])\n",
449
    "    for i, label in enumerate(y):\n",
450
    "        ret[i, label] = 1\n",
451
    "    return ret"
452
   ]
453
  },
454
  {
455
   "cell_type": "markdown",
456
   "metadata": {},
457
   "source": [
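458
    "## Extract 3D Normalized Data with Validation Set\n",
    "\n",
    "Hold out 20% of the bags as the test set, use the last 10% of the remaining bags as the validation set, and z-normalize every split with the per-feature mean and standard deviation of the training split only."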
459
   ]
460
  },
461
  {
462
   "cell_type": "code",
463
   "execution_count": 245,
464
   "metadata": {
465
    "ExecuteTime": {
466
     "end_time": "2019-06-19T05:26:08.931435Z",
467
     "start_time": "2019-06-19T05:26:08.927397Z"
468
    }
469
   },
470
   "outputs": [],
471
   "source": [
472
    "from sklearn.model_selection import train_test_split\n",
473
    "import pathlib"
474
   ]
475
  },
476
  {
477
   "cell_type": "code",
478
   "execution_count": 246,
479
   "metadata": {
480
    "ExecuteTime": {
481
     "end_time": "2019-06-19T05:26:10.295822Z",
482
     "start_time": "2019-06-19T05:26:09.832723Z"
483
    }
484
   },
485
   "outputs": [],
486
   "source": [
487
    "x_train_val_combined, x_test, y_train_val_combined, y_test = train_test_split(x, categorical_y_ver, test_size=0.20, random_state=42)"
488
   ]
489
  },
490
  {
491
   "cell_type": "code",
492
   "execution_count": 247,
493
   "metadata": {
494
    "ExecuteTime": {
495
     "end_time": "2019-06-19T05:26:11.260084Z",
496
     "start_time": "2019-06-19T05:26:11.253714Z"
497
    }
498
   },
499
   "outputs": [
500
    {
501
     "data": {
502
      "text/plain": [
503
       "array([1, 1, 0, ..., 1, 0, 1])"
504
      ]
505
     },
506
     "execution_count": 247,
507
     "metadata": {},
508
     "output_type": "execute_result"
509
    }
510
   ],
511
   "source": [
512
    "y_test"
513
   ]
514
  },
515
  {
516
   "cell_type": "code",
517
   "execution_count": 248,
518
   "metadata": {
519
    "ExecuteTime": {
520
     "end_time": "2019-06-18T17:24:41.832337Z",
521
     "start_time": "2019-06-18T17:24:41.141678Z"
522
    }
523
   },
524
   "outputs": [],
525
   "source": [
526
    "extractedDir = '/home/sf/data/SWELL-KW/'\n",
527
    "# def generateData(extractedDir):\n",
528
    "# x_train_val_combined, x_test, y_train_val_combined, y_test = readData(extractedDir)\n",
529
    "timesteps = x_train_val_combined.shape[-2] #128/256\n",
530
    "feats = x_train_val_combined.shape[-1]  #16\n",
531
    "\n",
532
    "trainSize = int(x_train_val_combined.shape[0]*0.9) #6566\n",
533
    "x_train, x_val = x_train_val_combined[:trainSize], x_train_val_combined[trainSize:] \n",
534
    "y_train, y_val = y_train_val_combined[:trainSize], y_train_val_combined[trainSize:]\n",
535
    "\n",
536
    "# normalization\n",
537
    "x_train = np.reshape(x_train, [-1, feats])\n",
538
    "mean = np.mean(x_train, axis=0)\n",
539
    "std = np.std(x_train, axis=0)\n",
540
    "\n",
541
    "# normalize train\n",
542
    "x_train = x_train - mean\n",
543
    "x_train = x_train / std\n",
544
    "x_train = np.reshape(x_train, [-1, timesteps, feats])\n",
545
    "\n",
546
    "# normalize val\n",
547
    "x_val = np.reshape(x_val, [-1, feats])\n",
548
    "x_val = x_val - mean\n",
549
    "x_val = x_val / std\n",
550
    "x_val = np.reshape(x_val, [-1, timesteps, feats])\n",
551
    "\n",
552
    "# normalize test\n",
553
    "x_test = np.reshape(x_test, [-1, feats])\n",
554
    "x_test = x_test - mean\n",
555
    "x_test = x_test / std\n",
556
    "x_test = np.reshape(x_test, [-1, timesteps, feats])\n",
557
    "\n",
558
    "# shuffle test, as this was remaining\n",
559
    "idx = np.arange(len(x_test))\n",
560
    "np.random.shuffle(idx)\n",
561
    "x_test = x_test[idx]\n",
562
    "y_test = y_test[idx]"
563
   ]
564
  },
565
  {
566
   "cell_type": "code",
567
   "execution_count": 249,
568
   "metadata": {
569
    "ExecuteTime": {
570
     "end_time": "2019-06-18T17:25:50.674962Z",
571
     "start_time": "2019-06-18T17:25:50.481068Z"
572
    }
573
   },
574
   "outputs": [
575
    {
576
     "name": "stdout",
577
     "output_type": "stream",
578
     "text": [
579
      "/home/sf/data/SWELL-KW//\n"
580
     ]
581
    }
582
   ],
583
   "source": [
584
    "# one-hot encoding of labels\n",
585
    "numOutput = 2\n",
586
    "y_train = one_hot(y_train, numOutput)\n",
587
    "y_val = one_hot(y_val, numOutput)\n",
588
    "y_test = one_hot(y_test, numOutput)\n",
589
    "extractedDir += '/'\n",
590
    "\n",
591
    "pathlib.Path(extractedDir + 'RAW').mkdir(parents=True, exist_ok = True)\n",
592
    "\n",
593
    "np.save(extractedDir + \"RAW/x_train\", x_train)\n",
594
    "np.save(extractedDir + \"RAW/y_train\", y_train)\n",
595
    "np.save(extractedDir + \"RAW/x_test\", x_test)\n",
596
    "np.save(extractedDir + \"RAW/y_test\", y_test)\n",
597
    "np.save(extractedDir + \"RAW/x_val\", x_val)\n",
598
    "np.save(extractedDir + \"RAW/y_val\", y_val)\n",
599
    "\n",
600
    "print(extractedDir)"
601
   ]
602
  },
603
  {
604
   "cell_type": "code",
605
   "execution_count": 250,
606
   "metadata": {
607
    "ExecuteTime": {
608
     "end_time": "2019-06-18T17:26:35.381650Z",
609
     "start_time": "2019-06-18T17:26:35.136645Z"
610
    }
611
   },
612
   "outputs": [
613
    {
614
     "name": "stdout",
615
     "output_type": "stream",
616
     "text": [
617
      "x_test.npy  x_train.npy  x_val.npy  y_test.npy  y_train.npy  y_val.npy\r\n"
618
     ]
619
    }
620
   ],
621
   "source": [
622
    "ls /home/sf/data/SWELL-KW/RAW"
623
   ]
624
  },
625
  {
626
   "cell_type": "code",
627
   "execution_count": 251,
628
   "metadata": {
629
    "ExecuteTime": {
630
     "end_time": "2019-06-18T17:27:34.458130Z",
631
     "start_time": "2019-06-18T17:27:34.323712Z"
632
    }
633
   },
634
   "outputs": [
635
    {
636
     "data": {
637
      "text/plain": [
638
       "(3679, 20, 22)"
639
      ]
640
     },
641
     "execution_count": 251,
642
     "metadata": {},
643
     "output_type": "execute_result"
644
    }
645
   ],
646
   "source": [
647
    "np.load('//home/sf/data/SWELL-KW/RAW/x_train.npy').shape"
648
   ]
649
  },
650
  {
651
   "cell_type": "markdown",
652
   "metadata": {},
653
   "source": [
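654
    "## Make 4D EMI Data (Bags, Subinstances, Subinstance Length, Features)\n",
    "\n",
    "Each 20-step bag is cut into sub-instances of length `subinstanceLen`, advancing by `subinstanceStride` steps; with the 8 / 3 setting used below this yields 5 sub-instances per bag."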
655
   ]
656
  },
657
  {
658
   "cell_type": "code",
659
   "execution_count": 2,
660
   "metadata": {
661
    "ExecuteTime": {
662
     "end_time": "2019-06-19T05:30:53.720179Z",
663
     "start_time": "2019-06-19T05:30:53.713756Z"
664
    }
665
   },
666
   "outputs": [],
667
   "source": [
668
    "def loadData(dirname):\n",
669
    "    x_train = np.load(dirname + '/' + 'x_train.npy')\n",
670
    "    y_train = np.load(dirname + '/' + 'y_train.npy')\n",
671
    "    x_test = np.load(dirname + '/' + 'x_test.npy')\n",
672
    "    y_test = np.load(dirname + '/' + 'y_test.npy')\n",
673
    "    x_val = np.load(dirname + '/' + 'x_val.npy')\n",
674
    "    y_val = np.load(dirname + '/' + 'y_val.npy')\n",
675
    "    return x_train, y_train, x_test, y_test, x_val, y_val"
676
   ]
677
  },
678
  {
679
   "cell_type": "code",
680
   "execution_count": 3,
681
   "metadata": {
682
    "ExecuteTime": {
683
     "end_time": "2019-06-18T17:29:23.614052Z",
684
     "start_time": "2019-06-18T17:29:23.601324Z"
685
    }
686
   },
687
   "outputs": [],
688
   "source": [
689
    "def bagData(X, Y, subinstanceLen, subinstanceStride):\n",
690
    "    numClass = 2\n",
691
    "    numSteps = 20\n",
692
    "    numFeats = 22\n",
693
    "    assert X.ndim == 3\n",
694
    "    assert X.shape[1] == numSteps\n",
695
    "    assert X.shape[2] == numFeats\n",
696
    "    assert subinstanceLen <= numSteps\n",
697
    "    assert subinstanceLen > 0\n",
698
    "    assert subinstanceStride <= numSteps\n",
699
    "    assert subinstanceStride >= 0\n",
700
    "    assert len(X) == len(Y)\n",
701
    "    assert Y.ndim == 2\n",
702
    "    assert Y.shape[1] == numClass\n",
703
    "    x_bagged = []\n",
704
    "    y_bagged = []\n",
705
    "    for i, point in enumerate(X[:, :, :]):\n",
706
    "        instanceList = []\n",
707
    "        start = 0\n",
708
    "        end = subinstanceLen\n",
709
    "        while True:\n",
710
    "            x = point[start:end, :]\n",
711
    "            if len(x) < subinstanceLen:\n",
712
    "                x_ = np.zeros([subinstanceLen, x.shape[1]])\n",
713
    "                x_[:len(x), :] = x[:, :]\n",
714
    "                x = x_\n",
715
    "            instanceList.append(x)\n",
716
    "            if end >= numSteps:\n",
717
    "                break\n",
718
    "            start += subinstanceStride\n",
719
    "            end += subinstanceStride\n",
720
    "        bag = np.array(instanceList)\n",
721
    "        numSubinstance = bag.shape[0]\n",
722
    "        label = Y[i]\n",
723
    "        label = np.argmax(label)\n",
724
    "        labelBag = np.zeros([numSubinstance, numClass])\n",
725
    "        labelBag[:, label] = 1\n",
726
    "        x_bagged.append(bag)\n",
727
    "        label = np.array(labelBag)\n",
728
    "        y_bagged.append(label)\n",
729
    "    return np.array(x_bagged), np.array(y_bagged)"
730
   ]
731
  },
732
  {
733
   "cell_type": "code",
734
   "execution_count": 4,
735
   "metadata": {
736
    "ExecuteTime": {
737
     "end_time": "2019-06-19T05:33:06.531994Z",
738
     "start_time": "2019-06-19T05:33:06.523884Z"
739
    }
740
   },
741
   "outputs": [],
742
   "source": [
743
    "def makeEMIData(subinstanceLen, subinstanceStride, sourceDir, outDir):\n",
744
    "    x_train, y_train, x_test, y_test, x_val, y_val = loadData(sourceDir)\n",
745
    "    x, y = bagData(x_train, y_train, subinstanceLen, subinstanceStride)\n",
746
    "    np.save(outDir + '/x_train.npy', x)\n",
747
    "    np.save(outDir + '/y_train.npy', y)\n",
748
    "    print('Num train %d' % len(x))\n",
749
    "    x, y = bagData(x_test, y_test, subinstanceLen, subinstanceStride)\n",
750
    "    np.save(outDir + '/x_test.npy', x)\n",
751
    "    np.save(outDir + '/y_test.npy', y)\n",
752
    "    print('Num test %d' % len(x))\n",
753
    "    x, y = bagData(x_val, y_val, subinstanceLen, subinstanceStride)\n",
754
    "    np.save(outDir + '/x_val.npy', x)\n",
755
    "    np.save(outDir + '/y_val.npy', y)\n",
756
    "    print('Num val %d' % len(x))"
757
   ]
758
  },
759
  {
760
   "cell_type": "code",
761
   "execution_count": 5,
762
   "metadata": {
763
    "ExecuteTime": {
764
     "end_time": "2019-06-19T05:35:20.960014Z",
765
     "start_time": "2019-06-19T05:35:20.050363Z"
766
    }
767
   },
768
   "outputs": [
769
    {
770
     "name": "stdout",
771
     "output_type": "stream",
772
     "text": [
773
      "Num train 3679\n",
774
      "Num test 1022\n",
775
      "Num val 409\n"
776
     ]
777
    }
778
   ],
779
   "source": [
780
    "subinstanceLen = 8\n",
781
    "subinstanceStride = 3\n",
782
    "extractedDir = '/home/sf/data/SWELL-KW'\n",
783
    "from os import mkdir\n",
784
    "#mkdir('/home/sf/data/SWELL-KW/8_3')\n",
785
    "rawDir = extractedDir + '/RAW'\n",
786
    "sourceDir = rawDir\n",
787
    "outDir = extractedDir + '/%d_%d/' % (subinstanceLen, subinstanceStride)\n",
788
    "makeEMIData(subinstanceLen, subinstanceStride, sourceDir, outDir)"
789
   ]
790
  },
791
  {
792
   "cell_type": "code",
793
   "execution_count": 6,
794
   "metadata": {
795
    "ExecuteTime": {
796
     "end_time": "2019-06-18T17:48:58.293843Z",
797
     "start_time": "2019-06-18T17:48:58.285383Z"
798
    }
799
   },
800
   "outputs": [
801
    {
802
     "data": {
803
      "text/plain": [
804
       "(3679, 5, 2)"
805
      ]
806
     },
807
     "execution_count": 6,
808
     "metadata": {},
809
     "output_type": "execute_result"
810
    }
811
   ],
812
   "source": [
813
    "np.load('/home/sf/data/SWELL-KW/8_3/y_train.npy').shape"
814
   ]
815
  },
816
  {
817
   "cell_type": "code",
818
   "execution_count": 8,
819
   "metadata": {
820
    "ExecuteTime": {
821
     "end_time": "2019-06-19T05:35:48.609552Z",
822
     "start_time": "2019-06-19T05:35:48.604291Z"
823
    }
824
   },
825
   "outputs": [],
826
   "source": [
827
    "from edgeml.graph.rnn import EMI_DataPipeline\n",
828
    "from edgeml.graph.rnn import EMI_BasicLSTM, EMI_FastGRNN, EMI_FastRNN, EMI_GRU\n",
829
    "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
830
    "import edgeml.utils"
831
   ]
832
  },
833
  {
834
   "cell_type": "code",
835
   "execution_count": 9,
836
   "metadata": {
837
    "ExecuteTime": {
838
     "end_time": "2019-06-19T05:43:48.032413Z",
839
     "start_time": "2019-06-19T05:43:47.987839Z"
840
    }
841
   },
842
   "outputs": [],
843
   "source": [
844
    "def lstm_experiment_generator(params, path = './DSAAR/64_16/'):\n",
845
    "    \"\"\"\n",
846
    "        Function that will generate the experiments to be run.\n",
847
    "        Inputs : \n",
848
    "        (1) Dictionary params, to set the network parameters.\n",
849
    "        (2) Name of the Model to be run from [EMI-LSTM, EMI-FastGRNN, EMI-GRU]\n",
850
    "        (3) Path to the dataset, where the csv files are present.\n",
851
    "    \"\"\"\n",
852
    "    \n",
853
    "    #Copy the contents of the params dictionary.\n",
854
    "    lstm_dict = {**params}\n",
855
    "    \n",
856
    "    #---------------------------PARAM SETTING----------------------#\n",
857
    "    \n",
858
    "    # Network parameters for our LSTM + FC Layer\n",
859
    "    NUM_HIDDEN = params[\"NUM_HIDDEN\"]\n",
860
    "    NUM_TIMESTEPS = params[\"NUM_TIMESTEPS\"]\n",
861
    "    ORIGINAL_NUM_TIMESTEPS = params[\"ORIGINAL_NUM_TIMESTEPS\"]\n",
862
    "    NUM_FEATS = params[\"NUM_FEATS\"]\n",
863
    "    FORGET_BIAS = params[\"FORGET_BIAS\"]\n",
864
    "    NUM_OUTPUT = params[\"NUM_OUTPUT\"]\n",
865
    "    USE_DROPOUT = True if (params[\"USE_DROPOUT\"] == 1) else False\n",
866
    "    KEEP_PROB = params[\"KEEP_PROB\"]\n",
867
    "\n",
868
    "    # For dataset API\n",
869
    "    PREFETCH_NUM = params[\"PREFETCH_NUM\"]\n",
870
    "    BATCH_SIZE = params[\"BATCH_SIZE\"]\n",
871
    "\n",
872
    "    # Number of epochs in *one iteration*\n",
873
    "    NUM_EPOCHS = params[\"NUM_EPOCHS\"]\n",
874
    "    # Number of iterations in *one round*. After each iteration,\n",
875
    "    # the model is dumped to disk. At the end of the current\n",
876
    "    # round, the best model among all the dumped models in the\n",
877
    "    # current round is picked up..\n",
878
    "    NUM_ITER = params[\"NUM_ITER\"]\n",
879
    "    # A round consists of multiple training iterations and a belief\n",
880
    "    # update step using the best model from all of these iterations\n",
881
    "    NUM_ROUNDS = params[\"NUM_ROUNDS\"]\n",
882
    "    LEARNING_RATE = params[\"LEARNING_RATE\"]\n",
883
    "\n",
884
    "    # A staging direcory to store models\n",
885
    "    MODEL_PREFIX = params[\"MODEL_PREFIX\"]\n",
886
    "    \n",
887
    "    #----------------------END OF PARAM SETTING----------------------#\n",
888
    "    \n",
889
    "    #----------------------DATA LOADING------------------------------#\n",
890
    "    \n",
891
    "    x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n",
892
    "    x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n",
893
    "    x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n",
894
    "\n",
895
    "    # BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
896
    "    # step of EMI/MI RNN\n",
897
    "    BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
898
    "    BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
899
    "    BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
900
    "    NUM_SUBINSTANCE = x_train.shape[1]\n",
901
    "    print(\"x_train shape is:\", x_train.shape)\n",
902
    "    print(\"y_train shape is:\", y_train.shape)\n",
903
    "    print(\"x_test shape is:\", x_val.shape)\n",
904
    "    print(\"y_test shape is:\", y_val.shape)\n",
905
    "    \n",
906
    "    #----------------------END OF DATA LOADING------------------------------#    \n",
907
    "    \n",
908
    "    #----------------------COMPUTATION GRAPH--------------------------------#\n",
909
    "    \n",
910
    "    # Define the linear secondary classifier\n",
911
    "    def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
912
    "        W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
913
    "        B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
914
    "        y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
915
    "        self.output = y_cap\n",
916
    "        self.graphCreated = True\n",
917
    "\n",
918
    "    def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
919
    "        y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
920
    "        self.output = y_cap\n",
921
    "        self.graphCreated = True\n",
922
    "\n",
923
    "    def feedDictFunc(self, keep_prob=None, inference=False, **kwargs):\n",
924
    "        if inference is False:\n",
925
    "            feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
926
    "        else:\n",
927
    "            feedDict = {self._emiGraph.keep_prob: 1.0}\n",
928
    "        return feedDict\n",
929
    "\n",
930
    "    EMI_BasicLSTM._createExtendedGraph = createExtendedGraph\n",
931
    "    EMI_BasicLSTM._restoreExtendedGraph = restoreExtendedGraph\n",
932
    "\n",
933
    "    if USE_DROPOUT is True:\n",
934
    "        EMI_Driver.feedDictFunc = feedDictFunc\n",
935
    "    \n",
936
    "    inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
937
    "    emiLSTM = EMI_BasicLSTM(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS,\n",
938
    "                            forgetBias=FORGET_BIAS, useDropout=USE_DROPOUT)\n",
939
    "    emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy',\n",
940
    "                             stepSize=LEARNING_RATE)\n",
941
    "    \n",
942
    "    tf.reset_default_graph()\n",
943
    "    g1 = tf.Graph()    \n",
944
    "    with g1.as_default():\n",
945
    "        # Obtain the iterators to each batch of the data\n",
946
    "        x_batch, y_batch = inputPipeline()\n",
947
    "        # Create the forward computation graph based on the iterators\n",
948
    "        y_cap = emiLSTM(x_batch)\n",
949
    "        # Create loss graphs and training routines\n",
950
    "        emiTrainer(y_cap, y_batch)\n",
951
    "        \n",
952
    "    #------------------------------END OF COMPUTATION GRAPH------------------------------#\n",
953
    "    \n",
954
    "    #-------------------------------------EMI DRIVER-------------------------------------#\n",
955
    "        \n",
956
    "    with g1.as_default():\n",
957
    "        emiDriver = EMI_Driver(inputPipeline, emiLSTM, emiTrainer)\n",
958
    "\n",
959
    "    emiDriver.initializeSession(g1)\n",
960
    "    y_updated, modelStats = emiDriver.run(numClasses=NUM_OUTPUT, x_train=x_train,\n",
961
    "                                          y_train=y_train, bag_train=BAG_TRAIN,\n",
962
    "                                          x_val=x_val, y_val=y_val, bag_val=BAG_VAL,\n",
963
    "                                          numIter=NUM_ITER, keep_prob=KEEP_PROB,\n",
964
    "                                          numRounds=NUM_ROUNDS, batchSize=BATCH_SIZE,\n",
965
    "                                          numEpochs=NUM_EPOCHS, modelPrefix=MODEL_PREFIX,\n",
966
    "                                          fracEMI=0.5, updatePolicy='top-k', k=1)\n",
967
    "    \n",
968
    "    #-------------------------------END OF EMI DRIVER-------------------------------------#\n",
969
    "    \n",
970
    "    #-----------------------------------EARLY SAVINGS-------------------------------------#\n",
971
    "    \n",
972
    "    \"\"\"\n",
973
    "        Early Prediction Policy: We make an early prediction based on the predicted classes\n",
974
    "        probability. If the predicted class probability > minProb at some step, we make\n",
975
    "        a prediction at that step.\n",
976
    "    \"\"\"\n",
977
    "    def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
978
    "        assert instanceOut.ndim == 2\n",
979
    "        classes = np.argmax(instanceOut, axis=1)\n",
980
    "        prob = np.max(instanceOut, axis=1)\n",
981
    "        index = np.where(prob >= minProb)[0]\n",
982
    "        if len(index) == 0:\n",
983
    "            assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
984
    "            return classes[-1], len(instanceOut) - 1\n",
985
    "        index = index[0]\n",
986
    "        return classes[index], index\n",
987
    "\n",
988
    "    def getEarlySaving(predictionStep, numTimeSteps, returnTotal=False):\n",
989
    "        predictionStep = predictionStep + 1\n",
990
    "        predictionStep = np.reshape(predictionStep, -1)\n",
991
    "        totalSteps = np.sum(predictionStep)\n",
992
    "        maxSteps = len(predictionStep) * numTimeSteps\n",
993
    "        savings = 1.0 - (totalSteps / maxSteps)\n",
994
    "        if returnTotal:\n",
995
    "            return savings, totalSteps\n",
996
    "        return savings\n",
997
    "    \n",
998
    "    #--------------------------------END OF EARLY SAVINGS---------------------------------#\n",
999
    "    \n",
1000
    "    #----------------------------------------BEST MODEL-----------------------------------#\n",
1001
    "    \n",
1002
    "    k = 2\n",
1003
    "    predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
1004
    "                                                                   minProb=0.99, keep_prob=1.0)\n",
1005
    "    bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
1006
    "    print('Accuracy at k = %d: %f' % (k,  np.mean((bagPredictions == BAG_TEST).astype(int))))\n",
1007
    "    mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n",
1008
    "    emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
1009
    "    total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
1010
    "    print('Savings due to MI-RNN : %f' % mi_savings)\n",
1011
    "    print('Savings due to Early prediction: %f' % emi_savings)\n",
1012
    "    print('Total Savings: %f' % (total_savings))\n",
1013
    "    \n",
1014
    "    #Store in the dictionary.\n",
1015
    "    lstm_dict[\"k\"] = k\n",
1016
    "    lstm_dict[\"accuracy\"] = np.mean((bagPredictions == BAG_TEST).astype(int))\n",
1017
    "    lstm_dict[\"total_savings\"] = total_savings\n",
1018
    "    lstm_dict[\"y_test\"] = BAG_TEST\n",
1019
    "    lstm_dict[\"y_pred\"] = bagPredictions\n",
1020
    "    \n",
1021
    "    # A slightly more detailed analysis method is provided. \n",
1022
    "    df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)\n",
1023
    "    print (tabulate(df, headers=list(df.columns), tablefmt='grid'))\n",
1024
    "    \n",
1025
    "    lstm_dict[\"detailed analysis\"] = df\n",
1026
    "    #----------------------------------END OF BEST MODEL-----------------------------------#\n",
1027
    "    \n",
1028
    "    #----------------------------------PICKING THE BEST MODEL------------------------------#\n",
1029
    "    \n",
1030
    "    devnull = open(os.devnull, 'r')\n",
1031
    "    for val in modelStats:\n",
1032
    "        round_, acc, modelPrefix, globalStep = val\n",
1033
    "        emiDriver.loadSavedGraphToNewSession(modelPrefix, globalStep, redirFile=devnull)\n",
1034
    "        predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
1035
    "                                                                   minProb=0.99, keep_prob=1.0)\n",
1036
    "\n",
1037
    "        bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
1038
    "        print(\"Round: %2d, Validation accuracy: %.4f\" % (round_, acc), end='')\n",
1039
    "        print(', Test Accuracy (k = %d): %f, ' % (k,  np.mean((bagPredictions == BAG_TEST).astype(int))), end='')\n",
1040
    "        print('Additional savings: %f' % getEarlySaving(predictionStep, NUM_TIMESTEPS)) \n",
1041
    "        \n",
1042
    "    \n",
1043
    "    #-------------------------------END OF PICKING THE BEST MODEL--------------------------#\n",
1044
    "\n",
1045
    "    return lstm_dict"
1046
   ]
1047
  },
1048
  {
1049
   "cell_type": "code",
1050
   "execution_count": 10,
1051
   "metadata": {
1052
    "ExecuteTime": {
1053
     "end_time": "2019-06-19T05:43:51.798834Z",
1054
     "start_time": "2019-06-19T05:43:51.792938Z"
1055
    }
1056
   },
1057
   "outputs": [],
1058
   "source": [
1059
    "def experiment_generator(params, path, model = 'lstm'):\n",
1060
    "    \n",
1061
    "    \n",
1062
    "    if (model == 'lstm'): return lstm_experiment_generator(params, path)\n",
1063
    "    elif (model == 'fastgrnn'): return fastgrnn_experiment_generator(params, path)\n",
1064
    "    elif (model == 'gru'): return gru_experiment_generator(params, path)\n",
1065
    "    elif (model == 'baseline'): return baseline_experiment_generator(params, path)\n",
1066
    "    \n",
1067
    "    return "
1068
   ]
1069
  },
1070
  {
1071
   "cell_type": "code",
1072
   "execution_count": 11,
1073
   "metadata": {},
1074
   "outputs": [
1075
    {
1076
     "data": {
1077
      "text/plain": [
1078
       "'/home/sf/data/EdgeML/tf/examples/EMI-RNN'"
1079
      ]
1080
     },
1081
     "execution_count": 11,
1082
     "metadata": {},
1083
     "output_type": "execute_result"
1084
    }
1085
   ],
1086
   "source": [
1087
    "pwd"
1088
   ]
1089
  },
1090
  {
1091
   "cell_type": "code",
1092
   "execution_count": 12,
1093
   "metadata": {
1094
    "ExecuteTime": {
1095
     "end_time": "2019-06-19T05:43:53.168413Z",
1096
     "start_time": "2019-06-19T05:43:53.164622Z"
1097
    }
1098
   },
1099
   "outputs": [],
1100
   "source": [
1101
    "import tensorflow as tf"
1102
   ]
1103
  },
1104
  {
1105
   "cell_type": "code",
1106
   "execution_count": 13,
1107
   "metadata": {
1108
    "ExecuteTime": {
1109
     "end_time": "2019-06-19T08:16:13.770690Z",
1110
     "start_time": "2019-06-19T05:45:31.777404Z"
1111
    },
1112
    "scrolled": true
1113
   },
1114
   "outputs": [
1115
    {
1116
     "name": "stdout",
1117
     "output_type": "stream",
1118
     "text": [
1119
      "x_train shape is: (3679, 5, 8, 22)\n",
1120
      "y_train shape is: (3679, 5, 2)\n",
1121
      "x_test shape is: (409, 5, 8, 22)\n",
1122
      "y_test shape is: (409, 5, 2)\n",
1123
      "Update policy: top-k\n",
1124
      "Training with MI-RNN loss for 15 rounds\n",
1125
      "Round: 0\n",
1126
      "Epoch   1 Batch   110 (  225) Loss 0.08400 Acc 0.59375 | Val acc 0.58924 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1000\n",
1127
      "Epoch   1 Batch   110 (  225) Loss 0.08510 Acc 0.57500 | Val acc 0.63081 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1001\n",
1128
      "Epoch   1 Batch   110 (  225) Loss 0.07742 Acc 0.61250 | Val acc 0.66015 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1002\n",
1129
      "Epoch   1 Batch   110 (  225) Loss 0.07279 Acc 0.67500 | Val acc 0.66504 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1003\n",
1130
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1003\n",
1131
      "Round: 1\n",
1132
      "Epoch   1 Batch   110 (  225) Loss 0.06990 Acc 0.71875 | Val acc 0.67971 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1004\n",
1133
      "Epoch   1 Batch   110 (  225) Loss 0.06914 Acc 0.72500 | Val acc 0.68215 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1005\n",
1134
      "Epoch   1 Batch   110 (  225) Loss 0.06405 Acc 0.76250 | Val acc 0.69682 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1006\n",
1135
      "Epoch   1 Batch   110 (  225) Loss 0.06493 Acc 0.73125 | Val acc 0.70660 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1007\n",
1136
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1007\n",
1137
      "Round: 2\n",
1138
      "Epoch   1 Batch   110 (  225) Loss 0.06088 Acc 0.76875 | Val acc 0.72127 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1008\n",
1139
      "Epoch   1 Batch   110 (  225) Loss 0.05882 Acc 0.75000 | Val acc 0.73594 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1009\n",
1140
      "Epoch   1 Batch   110 (  225) Loss 0.05941 Acc 0.79375 | Val acc 0.74083 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1010\n",
1141
      "Epoch   1 Batch   110 (  225) Loss 0.06156 Acc 0.76875 | Val acc 0.74817 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1011\n",
1142
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1011\n",
1143
      "Round: 3\n",
1144
      "Epoch   1 Batch   110 (  225) Loss 0.06403 Acc 0.76250 | Val acc 0.75550 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1012\n",
1145
      "Epoch   1 Batch   110 (  225) Loss 0.05794 Acc 0.75625 | Val acc 0.75795 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1013\n",
1146
      "Epoch   1 Batch   110 (  225) Loss 0.05490 Acc 0.82500 | Val acc 0.76773 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1014\n",
1147
      "Epoch   1 Batch   110 (  225) Loss 0.05780 Acc 0.81250 | Val acc 0.78240 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1015\n",
1148
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1015\n",
1149
      "Round: 4\n",
1150
      "Epoch   1 Batch   110 (  225) Loss 0.05379 Acc 0.83750 | Val acc 0.77995 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1016\n",
1151
      "Epoch   1 Batch   110 (  225) Loss 0.05840 Acc 0.80000 | Val acc 0.78240 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1017\n",
1152
      "Epoch   1 Batch   110 (  225) Loss 0.05547 Acc 0.81250 | Val acc 0.77506 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1018\n",
1153
      "Epoch   1 Batch   110 (  225) Loss 0.05212 Acc 0.81875 | Val acc 0.78973 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1019\n",
1154
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1019\n",
1155
      "Round: 5\n",
1156
      "Epoch   1 Batch   110 (  225) Loss 0.04983 Acc 0.84375 | Val acc 0.78973 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1020\n",
1157
      "Epoch   1 Batch   110 (  225) Loss 0.04638 Acc 0.85000 | Val acc 0.80196 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1021\n",
1158
      "Epoch   1 Batch   110 (  225) Loss 0.04734 Acc 0.83750 | Val acc 0.80440 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1022\n",
1159
      "Epoch   1 Batch   110 (  225) Loss 0.05063 Acc 0.85625 | Val acc 0.80440 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1023\n",
1160
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1022\n",
1161
      "Round: 6\n",
1162
      "Epoch   1 Batch   110 (  225) Loss 0.04497 Acc 0.86875 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1024\n",
1163
      "Epoch   1 Batch   110 (  225) Loss 0.04768 Acc 0.83125 | Val acc 0.79951 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1025\n",
1164
      "Epoch   1 Batch   110 (  225) Loss 0.04779 Acc 0.81875 | Val acc 0.80929 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1026\n",
1165
      "Epoch   1 Batch   110 (  225) Loss 0.04416 Acc 0.81250 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1027\n",
1166
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1027\n",
1167
      "Round: 7\n",
1168
      "Epoch   1 Batch   110 (  225) Loss 0.04235 Acc 0.84375 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1028\n",
1169
      "Epoch   1 Batch   110 (  225) Loss 0.04155 Acc 0.84375 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1029\n",
1170
      "Epoch   1 Batch   110 (  225) Loss 0.03803 Acc 0.85000 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1030\n",
1171
      "Epoch   1 Batch   110 (  225) Loss 0.04139 Acc 0.85000 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1031\n",
1172
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1031\n",
1173
      "Round: 8\n",
1174
      "Epoch   1 Batch   110 (  225) Loss 0.04385 Acc 0.81875 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1032\n",
1175
      "Epoch   1 Batch   110 (  225) Loss 0.03797 Acc 0.87500 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1033\n",
1176
      "Epoch   1 Batch   110 (  225) Loss 0.03902 Acc 0.83750 | Val acc 0.80929 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1034\n",
1177
      "Epoch   1 Batch   110 (  225) Loss 0.03691 Acc 0.85000 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1035\n",
1178
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1032\n",
1179
      "Round: 9\n",
1180
      "Epoch   1 Batch   110 (  225) Loss 0.03872 Acc 0.87500 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1036\n",
1181
      "Epoch   1 Batch   110 (  225) Loss 0.04312 Acc 0.82500 | Val acc 0.81907 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1037\n",
1182
      "Epoch   1 Batch   110 (  225) Loss 0.03735 Acc 0.87500 | Val acc 0.82152 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1038\n",
1183
      "Epoch   1 Batch   110 (  225) Loss 0.04043 Acc 0.84375 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1039\n",
1184
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1039\n",
1185
      "Round: 10\n",
1186
      "Epoch   1 Batch   110 (  225) Loss 0.03696 Acc 0.85000 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1040\n",
1187
      "Epoch   1 Batch   110 (  225) Loss 0.03923 Acc 0.86250 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1041\n",
1188
      "Epoch   1 Batch   110 (  225) Loss 0.03495 Acc 0.87500 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1042\n",
1189
      "Epoch   1 Batch   110 (  225) Loss 0.03587 Acc 0.88750 | Val acc 0.81907 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1043\n",
1190
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1042\n",
1191
      "Round: 11\n",
1192
      "Epoch   1 Batch   110 (  225) Loss 0.04023 Acc 0.86250 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1044\n",
1193
      "Epoch   1 Batch   110 (  225) Loss 0.03841 Acc 0.85625 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1045\n",
1194
      "Epoch   1 Batch   110 (  225) Loss 0.03969 Acc 0.85000 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1046\n"
1195
     ]
1196
    },
1197
    {
1198
     "name": "stdout",
1199
     "output_type": "stream",
1200
     "text": [
1201
      "Epoch   1 Batch   110 (  225) Loss 0.03921 Acc 0.85000 | Val acc 0.81418 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1047\n",
1202
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1044\n",
1203
      "Round: 12\n",
1204
      "Epoch   1 Batch   110 (  225) Loss 0.04475 Acc 0.81875 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1048\n",
1205
      "Epoch   1 Batch   110 (  225) Loss 0.03966 Acc 0.85625 | Val acc 0.83130 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1049\n",
1206
      "Epoch   1 Batch   110 (  225) Loss 0.03139 Acc 0.90000 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1050\n",
1207
      "Epoch   1 Batch   110 (  225) Loss 0.04275 Acc 0.81250 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1051\n",
1208
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1050\n",
1209
      "Round: 13\n",
1210
      "Epoch   1 Batch   110 (  225) Loss 0.03359 Acc 0.87500 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1052\n",
1211
      "Epoch   1 Batch   110 (  225) Loss 0.03795 Acc 0.86250 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1053\n",
1212
      "Epoch   1 Batch   110 (  225) Loss 0.03523 Acc 0.87500 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1054\n",
1213
      "Epoch   1 Batch   110 (  225) Loss 0.03383 Acc 0.89375 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1055\n",
1214
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1054\n",
1215
      "Round: 14\n",
1216
      "Epoch   1 Batch   110 (  225) Loss 0.02968 Acc 0.89375 | Val acc 0.81418 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1056\n",
1217
      "Epoch   1 Batch   110 (  225) Loss 0.03354 Acc 0.85625 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1057\n",
1218
      "Epoch   1 Batch   110 (  225) Loss 0.02782 Acc 0.90000 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1058\n",
1219
      "Epoch   1 Batch   110 (  225) Loss 0.02719 Acc 0.89375 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1059\n",
1220
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1058\n",
1221
      "Round: 15\n",
1222
      "Switching to EMI-Loss function\n",
1223
      "Epoch   1 Batch   110 (  225) Loss 0.44879 Acc 0.84375 | Val acc 0.80929 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1060\n",
1224
      "Epoch   1 Batch   110 (  225) Loss 0.42052 Acc 0.85000 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1061\n",
1225
      "Epoch   1 Batch   110 (  225) Loss 0.37942 Acc 0.87500 | Val acc 0.82152 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1062\n",
1226
      "Epoch   1 Batch   110 (  225) Loss 0.38046 Acc 0.88125 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1063\n",
1227
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1063\n",
1228
      "Round: 16\n",
1229
      "Epoch   1 Batch   110 (  225) Loss 0.38822 Acc 0.86875 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1064\n",
1230
      "Epoch   1 Batch   110 (  225) Loss 0.37139 Acc 0.88750 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1065\n",
1231
      "Epoch   1 Batch   110 (  225) Loss 0.38185 Acc 0.88125 | Val acc 0.82152 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1066\n",
1232
      "Epoch   1 Batch   110 (  225) Loss 0.37084 Acc 0.86875 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1067\n",
1233
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1065\n",
1234
      "Round: 17\n",
1235
      "Epoch   1 Batch   110 (  225) Loss 0.37820 Acc 0.89375 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1068\n",
1236
      "Epoch   1 Batch   110 (  225) Loss 0.36451 Acc 0.84375 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1069\n",
1237
      "Epoch   1 Batch   110 (  225) Loss 0.36738 Acc 0.89375 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1070\n",
1238
      "Epoch   1 Batch   110 (  225) Loss 0.40073 Acc 0.88750 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1071\n",
1239
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1068\n",
1240
      "Round: 18\n",
1241
      "Epoch   1 Batch   110 (  225) Loss 0.38454 Acc 0.87500 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1072\n",
1242
      "Epoch   1 Batch   110 (  225) Loss 0.36183 Acc 0.90000 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1073\n",
1243
      "Epoch   1 Batch   110 (  225) Loss 0.38793 Acc 0.90000 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1074\n",
1244
      "Epoch   1 Batch   110 (  225) Loss 0.38047 Acc 0.83750 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1075\n",
1245
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1072\n",
1246
      "Round: 19\n",
1247
      "Epoch   1 Batch   110 (  225) Loss 0.38943 Acc 0.85625 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1076\n",
1248
      "Epoch   1 Batch   110 (  225) Loss 0.40676 Acc 0.86250 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1077\n",
1249
      "Epoch   1 Batch   110 (  225) Loss 0.37826 Acc 0.90000 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1078\n",
1250
      "Epoch   1 Batch   110 (  225) Loss 0.39518 Acc 0.86250 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1079\n",
1251
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1078\n",
1252
      "Round: 20\n",
1253
      "Epoch   1 Batch   110 (  225) Loss 0.38738 Acc 0.85625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1080\n",
1254
      "Epoch   1 Batch   110 (  225) Loss 0.38468 Acc 0.85625 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1081\n",
1255
      "Epoch   1 Batch   110 (  225) Loss 0.35865 Acc 0.89375 | Val acc 0.83130 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1082\n",
1256
      "Epoch   1 Batch   110 (  225) Loss 0.33802 Acc 0.92500 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1083\n",
1257
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1083\n",
1258
      "Round: 21\n",
1259
      "Epoch   1 Batch   110 (  225) Loss 0.35882 Acc 0.89375 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1084\n",
1260
      "Epoch   1 Batch   110 (  225) Loss 0.37442 Acc 0.85625 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1085\n",
1261
      "Epoch   1 Batch   110 (  225) Loss 0.40794 Acc 0.87500 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1086\n",
1262
      "Epoch   1 Batch   110 (  225) Loss 0.33659 Acc 0.91250 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1087\n",
1263
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1084\n",
1264
      "Round: 22\n",
1265
      "Epoch   1 Batch   110 (  225) Loss 0.33654 Acc 0.88125 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1088\n",
1266
      "Epoch   1 Batch   110 (  225) Loss 0.36120 Acc 0.88125 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1089\n",
1267
      "Epoch   1 Batch   110 (  225) Loss 0.34685 Acc 0.88125 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1090\n",
1268
      "Epoch   1 Batch   110 (  225) Loss 0.32916 Acc 0.93125 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1091\n",
1269
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1091\n",
1270
      "Round: 23\n",
1271
      "Epoch   1 Batch   110 (  225) Loss 0.34161 Acc 0.90625 | Val acc 0.83130 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1092\n",
1272
      "Epoch   1 Batch   110 (  225) Loss 0.39510 Acc 0.85000 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1093\n"
1273
     ]
1274
    },
1275
    {
1276
     "name": "stdout",
1277
     "output_type": "stream",
1278
     "text": [
1279
      "Epoch   1 Batch   110 (  225) Loss 0.37442 Acc 0.91875 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1094\n",
1280
      "Epoch   1 Batch   110 (  225) Loss 0.32392 Acc 0.91875 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1095\n",
1281
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1094\n",
1282
      "Round: 24\n",
1283
      "Epoch   1 Batch   110 (  225) Loss 0.32550 Acc 0.90625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1096\n",
1284
      "Epoch   1 Batch   110 (  225) Loss 0.34901 Acc 0.90625 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1097\n",
1285
      "Epoch   1 Batch   110 (  225) Loss 0.34062 Acc 0.90625 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1098\n",
1286
      "Epoch   1 Batch   110 (  225) Loss 0.33095 Acc 0.90625 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1099\n",
1287
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1099\n",
1288
      "Round: 25\n",
1289
      "Epoch   1 Batch   110 (  225) Loss 0.34793 Acc 0.90000 | Val acc 0.85819 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1100\n",
1290
      "Epoch   1 Batch   110 (  225) Loss 0.32091 Acc 0.92500 | Val acc 0.86553 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1101\n",
1291
      "Epoch   1 Batch   110 (  225) Loss 0.35738 Acc 0.90625 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1102\n",
1292
      "Epoch   1 Batch   110 (  225) Loss 0.34036 Acc 0.90000 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1103\n",
1293
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1101\n",
1294
      "Round: 26\n",
1295
      "Epoch   1 Batch   110 (  225) Loss 0.34306 Acc 0.88750 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1104\n",
1296
      "Epoch   1 Batch   110 (  225) Loss 0.32850 Acc 0.92500 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1105\n",
1297
      "Epoch   1 Batch   110 (  225) Loss 0.34266 Acc 0.88750 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1106\n",
1298
      "Epoch   1 Batch   110 (  225) Loss 0.34165 Acc 0.89375 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1107\n",
1299
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1105\n",
1300
      "Round: 27\n",
1301
      "Epoch   1 Batch   110 (  225) Loss 0.32461 Acc 0.89375 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1108\n",
1302
      "Epoch   1 Batch   110 (  225) Loss 0.33608 Acc 0.89375 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1109\n",
1303
      "Epoch   1 Batch   110 (  225) Loss 0.33137 Acc 0.90625 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1110\n",
1304
      "Epoch   1 Batch   110 (  225) Loss 0.28914 Acc 0.95000 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1111\n",
1305
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1110\n",
1306
      "Round: 28\n",
1307
      "Epoch   1 Batch   110 (  225) Loss 0.35324 Acc 0.92500 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1112\n",
1308
      "Epoch   1 Batch   110 (  225) Loss 0.31689 Acc 0.91875 | Val acc 0.86797 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1113\n",
1309
      "Epoch   1 Batch   110 (  225) Loss 0.32767 Acc 0.92500 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1114\n",
1310
      "Epoch   1 Batch   110 (  225) Loss 0.30775 Acc 0.91875 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1115\n",
1311
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1113\n",
1312
      "Round: 29\n",
1313
      "Epoch   1 Batch   110 (  225) Loss 0.33176 Acc 0.91250 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1116\n",
1314
      "Epoch   1 Batch   110 (  225) Loss 0.31769 Acc 0.90625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1117\n",
1315
      "Epoch   1 Batch   110 (  225) Loss 0.32659 Acc 0.90625 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1118\n",
1316
      "Epoch   1 Batch   110 (  225) Loss 0.31592 Acc 0.93125 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1119\n",
1317
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1118\n",
1318
      "Accuracy at k = 2: 0.863992\n",
1319
      "Savings due to MI-RNN : 0.600000\n",
1320
      "Savings due to Early prediction: 0.208659\n",
1321
      "Total Savings: 0.683464\n",
1322
      "   len       acc  macro-fsc  macro-pre  macro-rec  micro-fsc  micro-pre  \\\n",
1323
      "0    1  0.842466   0.842041   0.849946   0.844332   0.842466   0.842466   \n",
1324
      "1    2  0.863992   0.863982   0.865217   0.864776   0.863992   0.863992   \n",
1325
      "2    3  0.867906   0.867860   0.867807   0.867995   0.867906   0.867906   \n",
1326
      "3    4  0.864971   0.864697   0.865674   0.864385   0.864971   0.864971   \n",
1327
      "4    5  0.857143   0.856430   0.860480   0.855905   0.857143   0.857143   \n",
1328
      "\n",
1329
      "   micro-rec  fscore_01  \n",
1330
      "0   0.842466   0.850233  \n",
1331
      "1   0.863992   0.865179  \n",
1332
      "2   0.867906   0.865404  \n",
1333
      "3   0.864971   0.858607  \n",
1334
      "4   0.857143   0.846316  \n",
1335
      "Max accuracy 0.867906 at subsequencelength 3\n",
1336
      "Max micro-f 0.867906 at subsequencelength 3\n",
1337
      "Micro-precision 0.867906 at subsequencelength 3\n",
1338
      "Micro-recall 0.867906 at subsequencelength 3\n",
1339
      "Max macro-f 0.867860 at subsequencelength 3\n",
1340
      "macro-precision 0.867807 at subsequencelength 3\n",
1341
      "macro-recall 0.867995 at subsequencelength 3\n",
1342
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
1343
      "|    |   len |      acc |   macro-fsc |   macro-pre |   macro-rec |   micro-fsc |   micro-pre |   micro-rec |   fscore_01 |\n",
1344
      "+====+=======+==========+=============+=============+=============+=============+=============+=============+=============+\n",
1345
      "|  0 |     1 | 0.842466 |    0.842041 |    0.849946 |    0.844332 |    0.842466 |    0.842466 |    0.842466 |    0.850233 |\n",
1346
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
1347
      "|  1 |     2 | 0.863992 |    0.863982 |    0.865217 |    0.864776 |    0.863992 |    0.863992 |    0.863992 |    0.865179 |\n",
1348
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
1349
      "|  2 |     3 | 0.867906 |    0.86786  |    0.867807 |    0.867995 |    0.867906 |    0.867906 |    0.867906 |    0.865404 |\n",
1350
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
1351
      "|  3 |     4 | 0.864971 |    0.864697 |    0.865674 |    0.864385 |    0.864971 |    0.864971 |    0.864971 |    0.858607 |\n",
1352
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
1353
      "|  4 |     5 | 0.857143 |    0.85643  |    0.86048  |    0.855905 |    0.857143 |    0.857143 |    0.857143 |    0.846316 |\n",
1354
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
1355
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1003\n",
1356
      "Round:  0, Validation accuracy: 0.6650, Test Accuracy (k = 2): 0.684932, Additional savings: 0.006531\n",
1357
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1007\n",
1358
      "Round:  1, Validation accuracy: 0.7066, Test Accuracy (k = 2): 0.687867, Additional savings: 0.008072\n",
1359
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1011\n",
1360
      "Round:  2, Validation accuracy: 0.7482, Test Accuracy (k = 2): 0.721135, Additional savings: 0.014555\n",
1361
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1015\n",
1362
      "Round:  3, Validation accuracy: 0.7824, Test Accuracy (k = 2): 0.745597, Additional savings: 0.020597\n",
1363
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1019\n",
1364
      "Round:  4, Validation accuracy: 0.7897, Test Accuracy (k = 2): 0.776908, Additional savings: 0.027740\n",
1365
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1022\n"
1366
     ]
1367
    },
1368
    {
1369
     "name": "stdout",
1370
     "output_type": "stream",
1371
     "text": [
1372
      "Round:  5, Validation accuracy: 0.8044, Test Accuracy (k = 2): 0.793542, Additional savings: 0.037769\n",
1373
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1027\n",
1374
      "Round:  6, Validation accuracy: 0.8337, Test Accuracy (k = 2): 0.817025, Additional savings: 0.049315\n",
1375
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1031\n",
1376
      "Round:  7, Validation accuracy: 0.8289, Test Accuracy (k = 2): 0.826810, Additional savings: 0.054795\n",
1377
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1032\n",
1378
      "Round:  8, Validation accuracy: 0.8362, Test Accuracy (k = 2): 0.825832, Additional savings: 0.060372\n",
1379
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1039\n",
1380
      "Round:  9, Validation accuracy: 0.8264, Test Accuracy (k = 2): 0.821918, Additional savings: 0.068249\n",
1381
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1042\n",
1382
      "Round: 10, Validation accuracy: 0.8289, Test Accuracy (k = 2): 0.820939, Additional savings: 0.074609\n",
1383
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1044\n",
1384
      "Round: 11, Validation accuracy: 0.8362, Test Accuracy (k = 2): 0.830724, Additional savings: 0.078914\n",
1385
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1050\n",
1386
      "Round: 12, Validation accuracy: 0.8362, Test Accuracy (k = 2): 0.835616, Additional savings: 0.090411\n",
1387
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1054\n",
1388
      "Round: 13, Validation accuracy: 0.8411, Test Accuracy (k = 2): 0.835616, Additional savings: 0.098973\n",
1389
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1058\n",
1390
      "Round: 14, Validation accuracy: 0.8460, Test Accuracy (k = 2): 0.829746, Additional savings: 0.109638\n",
1391
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1063\n",
1392
      "Round: 15, Validation accuracy: 0.8337, Test Accuracy (k = 2): 0.827789, Additional savings: 0.111913\n",
1393
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1065\n",
1394
      "Round: 16, Validation accuracy: 0.8289, Test Accuracy (k = 2): 0.820939, Additional savings: 0.128156\n",
1395
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1068\n",
1396
      "Round: 17, Validation accuracy: 0.8411, Test Accuracy (k = 2): 0.830724, Additional savings: 0.129525\n",
1397
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1072\n",
1398
      "Round: 18, Validation accuracy: 0.8337, Test Accuracy (k = 2): 0.829746, Additional savings: 0.134760\n",
1399
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1078\n",
1400
      "Round: 19, Validation accuracy: 0.8386, Test Accuracy (k = 2): 0.837573, Additional savings: 0.143102\n",
1401
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1083\n",
1402
      "Round: 20, Validation accuracy: 0.8435, Test Accuracy (k = 2): 0.835616, Additional savings: 0.161717\n",
1403
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1084\n",
1404
      "Round: 21, Validation accuracy: 0.8411, Test Accuracy (k = 2): 0.850294, Additional savings: 0.167392\n",
1405
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1091\n",
1406
      "Round: 22, Validation accuracy: 0.8509, Test Accuracy (k = 2): 0.846380, Additional savings: 0.177886\n",
1407
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1094\n",
1408
      "Round: 23, Validation accuracy: 0.8484, Test Accuracy (k = 2): 0.860078, Additional savings: 0.175612\n",
1409
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1099\n",
1410
      "Round: 24, Validation accuracy: 0.8460, Test Accuracy (k = 2): 0.851272, Additional savings: 0.189750\n",
1411
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1101\n",
1412
      "Round: 25, Validation accuracy: 0.8655, Test Accuracy (k = 2): 0.854207, Additional savings: 0.193395\n",
1413
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1105\n",
1414
      "Round: 26, Validation accuracy: 0.8557, Test Accuracy (k = 2): 0.853229, Additional savings: 0.197114\n",
1415
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1110\n",
1416
      "Round: 27, Validation accuracy: 0.8557, Test Accuracy (k = 2): 0.851272, Additional savings: 0.199731\n",
1417
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1113\n",
1418
      "Round: 28, Validation accuracy: 0.8680, Test Accuracy (k = 2): 0.860078, Additional savings: 0.208586\n",
1419
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1118\n",
1420
      "Round: 29, Validation accuracy: 0.8557, Test Accuracy (k = 2): 0.863992, Additional savings: 0.208659\n",
1421
      "Results for this run have been saved at lstm .\n"
1422
     ]
1423
    }
1424
   ],
1425
   "source": [
1426
    "## Baseline EMI-LSTM\n",
1427
    "\n",
1428
    "dataset = 'SWELL-KW'\n",
1429
    "path = '/home/sf/data/SWELL-KW/8_3/'\n",
1430
    "\n",
1431
    "#Choose model from among [lstm, fastgrnn, gru]\n",
1432
    "model = 'lstm'\n",
1433
    "\n",
1434
    "# Dictionary to set the parameters.\n",
1435
    "params = {\n",
1436
    "    \"NUM_HIDDEN\" : 128,\n",
1437
    "    \"NUM_TIMESTEPS\" : 8, #subinstance length.\n",
1438
    "    \"ORIGINAL_NUM_TIMESTEPS\" : 20,\n",
1439
    "    \"NUM_FEATS\" : 22,\n",
1440
    "    \"FORGET_BIAS\" : 1.0,\n",
1441
    "    \"NUM_OUTPUT\" : 2,\n",
1442
    "    \"USE_DROPOUT\" : 1, # '1' -> True. '0' -> False\n",
1443
    "    \"KEEP_PROB\" : 0.75,\n",
1444
    "    \"PREFETCH_NUM\" : 5,\n",
1445
    "    \"BATCH_SIZE\" : 32,\n",
1446
    "    \"NUM_EPOCHS\" : 2,\n",
1447
    "    \"NUM_ITER\" : 4,\n",
1448
    "    \"NUM_ROUNDS\" : 30,\n",
1449
    "    \"LEARNING_RATE\" : 0.001,\n",
1450
    "    \"FRAC_EMI\" : 0.5,\n",
1451
    "    \"MODEL_PREFIX\" : '/home/sf/data/SWELL-KW/models/model-lstm'\n",
1452
    "}\n",
1453
    "\n",
1454
    "#Preprocess data, and load the train,test and validation splits.\n",
1455
    "lstm_dict = lstm_experiment_generator(params, path)\n",
1456
    "\n",
1457
    "# Create the directory that will hold the results of this run.\n",
1458
    "dirname = model\n",
1459
    "pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)\n",
1460
    "\n",
1461
    "now = datetime.datetime.now()\n",
1462
    "# NOTE(review): '|' is not a legal filename character on Windows; kept so names match earlier runs.\n",
1463
    "filename = f\"{now.year}-{now.month}-{now.day}|{now.hour}-{now.minute}\"\n",
1464
    "\n",
1465
    "# Save the params and results; the 'with' block guarantees the pickle is flushed and closed.\n",
1466
    "with open(dirname + \"/lstm_dict_\" + filename + \".pkl\", mode='wb') as fh:\n",
1467
    "    pkl.dump(lstm_dict, fh)\n",
1468
    "print(\"Results for this run have been saved at\", dirname, \".\")"
1469
   ]
1470
  }
1471
 ],
1472
 "metadata": {
1473
  "kernelspec": {
1474
   "display_name": "Python 3",
1475
   "language": "python",
1476
   "name": "python3"
1477
  },
1478
  "language_info": {
1479
   "codemirror_mode": {
1480
    "name": "ipython",
1481
    "version": 3
1482
   },
1483
   "file_extension": ".py",
1484
   "mimetype": "text/x-python",
1485
   "name": "python",
1486
   "nbconvert_exporter": "python",
1487
   "pygments_lexer": "ipython3",
1488
   "version": "3.7.3"
1489
  }
1490
 },
1491
 "nbformat": 4,
1492
 "nbformat_minor": 2
1493
}