a b/SWELL-KW/SWELL-KW_GRU.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "markdown",
5
   "metadata": {},
6
   "source": [
7
    "# SWELL-KW GRU"
8
   ]
9
  },
10
  {
11
   "cell_type": "markdown",
12
   "metadata": {},
13
   "source": [
14
    "Adapted from Microsoft's EdgeML notebooks (authored by Dennis et al.), available at https://github.com/microsoft/EdgeML."
15
   ]
16
  },
17
  {
18
   "cell_type": "code",
19
   "execution_count": 1,
20
   "metadata": {},
21
   "outputs": [],
22
   "source": [
23
    "import pandas as pd\n",
24
    "import numpy as np\n",
25
    "from tabulate import tabulate\n",
26
    "import os\n",
27
    "import datetime as datetime\n",
28
    "import pickle as pkl\n",
29
    "import pathlib"
30
   ]
31
  },
32
  {
33
   "cell_type": "code",
34
   "execution_count": 2,
35
   "metadata": {
36
    "ExecuteTime": {
37
     "end_time": "2018-12-14T14:17:51.796585Z",
38
     "start_time": "2018-12-14T14:17:49.648375Z"
39
    }
40
   },
41
   "outputs": [
42
    {
43
     "name": "stderr",
44
     "output_type": "stream",
45
     "text": [
46
      "Using TensorFlow backend.\n"
47
     ]
48
    }
49
   ],
50
   "source": [
51
    "from __future__ import print_function\n",
52
    "import os\n",
53
    "import sys\n",
54
    "import tensorflow as tf\n",
55
    "import numpy as np\n",
56
    "# Make sure edgeml is on the Python path\n",
57
    "sys.path.insert(0, '../../')\n",
58
    "# Expose only GPU 0 to TensorFlow (set to '-1' instead to force CPU-only processing).\n",
59
    "os.environ['CUDA_VISIBLE_DEVICES'] ='0'\n",
60
    "\n",
61
    "np.random.seed(42)\n",
62
    "tf.set_random_seed(42)\n",
63
    "\n",
64
    "# MI-RNN and EMI-RNN imports\n",
65
    "from edgeml.graph.rnn import EMI_DataPipeline\n",
66
    "from edgeml.graph.rnn import EMI_GRU\n",
67
    "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
68
    "import edgeml.utils\n",
69
    "\n",
70
    "import keras.backend as K\n",
71
    "cfg = K.tf.ConfigProto()\n",
72
    "cfg.gpu_options.allow_growth = True\n",
73
    "K.set_session(K.tf.Session(config=cfg))"
74
   ]
75
  },
76
  {
77
   "cell_type": "code",
78
   "execution_count": 4,
79
   "metadata": {
80
    "ExecuteTime": {
81
     "end_time": "2018-12-14T14:17:51.803381Z",
82
     "start_time": "2018-12-14T14:17:51.798799Z"
83
    }
84
   },
85
   "outputs": [],
86
   "source": [
87
    "# Network parameters for our GRU + FC Layer\n",
88
    "NUM_HIDDEN = 128\n",
89
    "NUM_TIMESTEPS = 8\n",
90
    "ORIGINAL_NUM_TIMESTEPS = 20\n",
91
    "NUM_FEATS = 22\n",
92
    "FORGET_BIAS = 1.0\n",
93
    "NUM_OUTPUT = 2\n",
94
    "USE_DROPOUT = True\n",
95
    "KEEP_PROB = 0.75\n",
96
    "\n",
97
    "# For dataset API\n",
98
    "PREFETCH_NUM = 5\n",
99
    "BATCH_SIZE = 32\n",
100
    "\n",
101
    "# Number of epochs in *one iteration*\n",
102
    "NUM_EPOCHS = 2\n",
103
    "# Number of iterations in *one round*. After each iteration,\n",
104
    "# the model is dumped to disk. At the end of the current\n",
105
    "# round, the best model among all the dumped models in the\n",
106
    "# current round is picked up.\n",
107
    "NUM_ITER = 4\n",
108
    "# A round consists of multiple training iterations and a belief\n",
109
    "# update step using the best model from all of these iterations\n",
110
    "NUM_ROUNDS = 30\n",
111
    "LEARNING_RATE=0.001\n",
112
    "\n",
113
    "# A staging directory to store models\n",
114
    "MODEL_PREFIX = '/home/sf/data/SWELL-KW/models/GRU/model-gru'"
115
   ]
116
  },
117
  {
118
   "cell_type": "markdown",
119
   "metadata": {
120
    "heading_collapsed": true
121
   },
122
   "source": [
123
    "# Loading Data"
124
   ]
125
  },
126
  {
127
   "cell_type": "code",
128
   "execution_count": 5,
129
   "metadata": {
130
    "ExecuteTime": {
131
     "end_time": "2018-12-14T14:17:52.040352Z",
132
     "start_time": "2018-12-14T14:17:51.805319Z"
133
    },
134
    "hidden": true
135
   },
136
   "outputs": [
137
    {
138
     "name": "stdout",
139
     "output_type": "stream",
140
     "text": [
141
      "x_train shape is: (3679, 5, 8, 22)\n",
142
      "y_train shape is: (3679, 5, 2)\n",
143
      "x_test shape is: (409, 5, 8, 22)\n",
144
      "y_test shape is: (409, 5, 2)\n"
145
     ]
146
    }
147
   ],
148
   "source": [
149
    "# Loading the data\n",
150
    "x_train, y_train = np.load('/home/sf/data/SWELL-KW/8_3/x_train.npy'), np.load('/home/sf/data/SWELL-KW/8_3/y_train.npy')\n",
151
    "x_test, y_test = np.load('/home/sf/data/SWELL-KW/8_3/x_test.npy'), np.load('/home/sf/data/SWELL-KW/8_3/y_test.npy')\n",
152
    "x_val, y_val = np.load('/home/sf/data/SWELL-KW/8_3/x_val.npy'), np.load('/home/sf/data/SWELL-KW/8_3/y_val.npy')\n",
153
    "\n",
154
    "# BAG_TEST, BAG_TRAIN, BAG_VAL represent bag-level labels. These are used for the label update\n",
155
    "# step of EMI/MI RNN\n",
156
    "BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
157
    "BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
158
    "BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
159
    "NUM_SUBINSTANCE = x_train.shape[1]\n",
160
    "print(\"x_train shape is:\", x_train.shape)\n",
161
    "print(\"y_train shape is:\", y_train.shape)\n",
162
    "print(\"x_test shape is:\", x_test.shape)\n",
163
    "print(\"y_test shape is:\", y_test.shape)"
164
   ]
165
  },
166
  {
167
   "cell_type": "markdown",
168
   "metadata": {},
169
   "source": [
170
    "# Computation Graph"
171
   ]
172
  },
173
  {
174
   "cell_type": "code",
175
   "execution_count": 6,
176
   "metadata": {
177
    "ExecuteTime": {
178
     "end_time": "2018-12-14T14:17:52.053161Z",
179
     "start_time": "2018-12-14T14:17:52.042928Z"
180
    }
181
   },
182
   "outputs": [],
183
   "source": [
184
    "# Define the linear secondary classifier\n",
185
    "def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
186
    "    W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
187
    "    B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
188
    "    y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
189
    "    self.output = y_cap\n",
190
    "    self.graphCreated = True\n",
191
    "\n",
192
    "def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
193
    "    y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
194
    "    self.output = y_cap\n",
195
    "    self.graphCreated = True\n",
196
    "    \n",
197
    "def feedDictFunc(self, keep_prob=None, inference=False, **kwargs):\n",
198
    "    if inference is False:\n",
199
    "        feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
200
    "    else:\n",
201
    "        feedDict = {self._emiGraph.keep_prob: 1.0}\n",
202
    "    return feedDict\n",
203
    "    \n",
204
    "EMI_GRU._createExtendedGraph = createExtendedGraph\n",
205
    "EMI_GRU._restoreExtendedGraph = restoreExtendedGraph\n",
206
    "\n",
207
    "if USE_DROPOUT is True:\n",
208
    "    EMI_Driver.feedDictFunc = feedDictFunc"
209
   ]
210
  },
211
  {
212
   "cell_type": "code",
213
   "execution_count": 7,
214
   "metadata": {
215
    "ExecuteTime": {
216
     "end_time": "2018-12-14T14:17:52.335299Z",
217
     "start_time": "2018-12-14T14:17:52.055483Z"
218
    }
219
   },
220
   "outputs": [],
221
   "source": [
222
    "inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
223
    "emiGRU = EMI_GRU(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS,\n",
224
    "                        useDropout=USE_DROPOUT)\n",
225
    "emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy',\n",
226
    "                         stepSize=LEARNING_RATE)"
227
   ]
228
  },
229
  {
230
   "cell_type": "code",
231
   "execution_count": 8,
232
   "metadata": {
233
    "ExecuteTime": {
234
     "end_time": "2018-12-14T14:18:05.031382Z",
235
     "start_time": "2018-12-14T14:17:52.338750Z"
236
    }
237
   },
238
   "outputs": [],
239
   "source": [
240
    "tf.reset_default_graph()\n",
241
    "g1 = tf.Graph()    \n",
242
    "with g1.as_default():\n",
243
    "    # Obtain the iterators to each batch of the data\n",
244
    "    x_batch, y_batch = inputPipeline()\n",
245
    "    # Create the forward computation graph based on the iterators\n",
246
    "    y_cap = emiGRU(x_batch)\n",
247
    "    # Create loss graphs and training routines\n",
248
    "    emiTrainer(y_cap, y_batch)"
249
   ]
250
  },
251
  {
252
   "cell_type": "markdown",
253
   "metadata": {},
254
   "source": [
255
    "# EMI Driver"
256
   ]
257
  },
258
  {
259
   "cell_type": "code",
260
   "execution_count": 9,
261
   "metadata": {
262
    "ExecuteTime": {
263
     "end_time": "2018-12-14T14:35:15.209910Z",
264
     "start_time": "2018-12-14T14:18:05.034359Z"
265
    },
266
    "scrolled": true
267
   },
268
   "outputs": [
269
    {
270
     "name": "stdout",
271
     "output_type": "stream",
272
     "text": [
273
      "Update policy: top-k\n",
274
      "Training with MI-RNN loss for 15 rounds\n",
275
      "Round: 0\n",
276
      "Epoch   1 Batch   110 (  225) Loss 0.08950 Acc 0.63125 | Val acc 0.58680 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1000\n",
277
      "Epoch   1 Batch   110 (  225) Loss 0.08044 Acc 0.61875 | Val acc 0.60636 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1001\n",
278
      "Epoch   1 Batch   110 (  225) Loss 0.08588 Acc 0.59375 | Val acc 0.61369 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1002\n",
279
      "Epoch   1 Batch   110 (  225) Loss 0.06992 Acc 0.70625 | Val acc 0.63570 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1003\n",
280
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1003\n",
281
      "Round: 1\n",
282
      "Epoch   1 Batch   110 (  225) Loss 0.07058 Acc 0.65625 | Val acc 0.64792 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1004\n",
283
      "Epoch   1 Batch   110 (  225) Loss 0.06464 Acc 0.71250 | Val acc 0.66748 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1005\n",
284
      "Epoch   1 Batch   110 (  225) Loss 0.06516 Acc 0.70000 | Val acc 0.68704 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1006\n",
285
      "Epoch   1 Batch   110 (  225) Loss 0.06636 Acc 0.72500 | Val acc 0.69927 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1007\n",
286
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1007\n",
287
      "Round: 2\n",
288
      "Epoch   1 Batch   110 (  225) Loss 0.06905 Acc 0.68750 | Val acc 0.69438 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1008\n",
289
      "Epoch   1 Batch   110 (  225) Loss 0.06327 Acc 0.75625 | Val acc 0.68949 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1009\n",
290
      "Epoch   1 Batch   110 (  225) Loss 0.06798 Acc 0.75625 | Val acc 0.70416 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1010\n",
291
      "Epoch   1 Batch   110 (  225) Loss 0.06369 Acc 0.71875 | Val acc 0.69193 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1011\n",
292
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1010\n",
293
      "Round: 3\n",
294
      "Epoch   1 Batch   110 (  225) Loss 0.06522 Acc 0.73750 | Val acc 0.70171 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1012\n",
295
      "Epoch   1 Batch   110 (  225) Loss 0.06332 Acc 0.74375 | Val acc 0.70905 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1013\n",
296
      "Epoch   1 Batch   110 (  225) Loss 0.06348 Acc 0.75625 | Val acc 0.73105 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1014\n",
297
      "Epoch   1 Batch   110 (  225) Loss 0.06125 Acc 0.75625 | Val acc 0.72616 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1015\n",
298
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1014\n",
299
      "Round: 4\n",
300
      "Epoch   1 Batch   110 (  225) Loss 0.06360 Acc 0.73750 | Val acc 0.72616 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1016\n",
301
      "Epoch   1 Batch   110 (  225) Loss 0.06455 Acc 0.76250 | Val acc 0.74328 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1017\n",
302
      "Epoch   1 Batch   110 (  225) Loss 0.06291 Acc 0.76250 | Val acc 0.74328 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1018\n",
303
      "Epoch   1 Batch   110 (  225) Loss 0.05980 Acc 0.78750 | Val acc 0.75061 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1019\n",
304
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1019\n",
305
      "Round: 5\n",
306
      "Epoch   1 Batch   110 (  225) Loss 0.06897 Acc 0.76250 | Val acc 0.75061 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1020\n",
307
      "Epoch   1 Batch   110 (  225) Loss 0.06064 Acc 0.75625 | Val acc 0.76039 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1021\n",
308
      "Epoch   1 Batch   110 (  225) Loss 0.05520 Acc 0.78750 | Val acc 0.76773 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1022\n",
309
      "Epoch   1 Batch   110 (  225) Loss 0.05275 Acc 0.84375 | Val acc 0.75795 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1023\n",
310
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1022\n",
311
      "Round: 6\n",
312
      "Epoch   1 Batch   110 (  225) Loss 0.05810 Acc 0.78750 | Val acc 0.77017 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1024\n",
313
      "Epoch   1 Batch   110 (  225) Loss 0.05748 Acc 0.76875 | Val acc 0.76528 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1025\n",
314
      "Epoch   1 Batch   110 (  225) Loss 0.06343 Acc 0.81250 | Val acc 0.77751 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1026\n",
315
      "Epoch   1 Batch   110 (  225) Loss 0.05767 Acc 0.78125 | Val acc 0.76039 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1027\n",
316
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1026\n",
317
      "Round: 7\n",
318
      "Epoch   1 Batch   110 (  225) Loss 0.05564 Acc 0.80000 | Val acc 0.76528 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1028\n",
319
      "Epoch   1 Batch   110 (  225) Loss 0.06886 Acc 0.77500 | Val acc 0.76039 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1029\n",
320
      "Epoch   1 Batch   110 (  225) Loss 0.05440 Acc 0.78125 | Val acc 0.77751 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1030\n",
321
      "Epoch   1 Batch   110 (  225) Loss 0.05853 Acc 0.76875 | Val acc 0.78484 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1031\n",
322
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1031\n",
323
      "Round: 8\n",
324
      "Epoch   1 Batch   110 (  225) Loss 0.05785 Acc 0.75000 | Val acc 0.77751 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1032\n",
325
      "Epoch   1 Batch   110 (  225) Loss 0.05716 Acc 0.78750 | Val acc 0.80196 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1033\n",
326
      "Epoch   1 Batch   110 (  225) Loss 0.05379 Acc 0.81875 | Val acc 0.79707 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1034\n",
327
      "Epoch   1 Batch   110 (  225) Loss 0.05206 Acc 0.81875 | Val acc 0.82152 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1035\n",
328
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1035\n",
329
      "Round: 9\n",
330
      "Epoch   1 Batch   110 (  225) Loss 0.05389 Acc 0.80625 | Val acc 0.81418 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1036\n",
331
      "Epoch   1 Batch   110 (  225) Loss 0.05023 Acc 0.79375 | Val acc 0.82152 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1037\n",
332
      "Epoch   1 Batch   110 (  225) Loss 0.04696 Acc 0.80625 | Val acc 0.81907 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1038\n",
333
      "Epoch   1 Batch   110 (  225) Loss 0.04655 Acc 0.82500 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1039\n",
334
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1039\n",
335
      "Round: 10\n",
336
      "Epoch   1 Batch   110 (  225) Loss 0.04700 Acc 0.83125 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1040\n",
337
      "Epoch   1 Batch   110 (  225) Loss 0.04426 Acc 0.83125 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1041\n",
338
      "Epoch   1 Batch   110 (  225) Loss 0.05274 Acc 0.81875 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1042\n",
339
      "Epoch   1 Batch   110 (  225) Loss 0.04178 Acc 0.82500 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1043\n",
340
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1043\n",
341
      "Round: 11\n",
342
      "Epoch   1 Batch   110 (  225) Loss 0.05123 Acc 0.78125 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1044\n",
343
      "Epoch   1 Batch   110 (  225) Loss 0.05073 Acc 0.82500 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1045\n",
344
      "Epoch   1 Batch   110 (  225) Loss 0.04487 Acc 0.82500 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1046\n",
345
      "Epoch   1 Batch   110 (  225) Loss 0.04353 Acc 0.84375 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1047\n",
346
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1044\n",
347
      "Round: 12\n",
348
      "Epoch   1 Batch   110 (  225) Loss 0.04450 Acc 0.85625 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1048\n",
349
      "Epoch   1 Batch   110 (  225) Loss 0.04483 Acc 0.84375 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1049\n",
350
      "Epoch   1 Batch   110 (  225) Loss 0.03620 Acc 0.86250 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1050\n",
351
      "Epoch   1 Batch   110 (  225) Loss 0.03725 Acc 0.87500 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1051\n",
352
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1048\n",
353
      "Round: 13\n",
354
      "Epoch   1 Batch   110 (  225) Loss 0.04427 Acc 0.81250 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1052\n",
355
      "Epoch   1 Batch   110 (  225) Loss 0.03978 Acc 0.87500 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1053\n",
356
      "Epoch   1 Batch   110 (  225) Loss 0.04293 Acc 0.83125 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1054\n",
357
      "Epoch   1 Batch   110 (  225) Loss 0.03750 Acc 0.85625 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1055\n",
358
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1054\n",
359
      "Round: 14\n",
360
      "Epoch   1 Batch   110 (  225) Loss 0.04008 Acc 0.83125 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1056\n",
361
      "Epoch   1 Batch   110 (  225) Loss 0.03921 Acc 0.85000 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1057\n",
362
      "Epoch   1 Batch   110 (  225) Loss 0.03854 Acc 0.85625 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1058\n",
363
      "Epoch   1 Batch   110 (  225) Loss 0.03663 Acc 0.88125 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1059\n",
364
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1057\n",
365
      "Round: 15\n",
366
      "Switching to EMI-Loss function\n",
367
      "Epoch   1 Batch   110 (  225) Loss 0.45096 Acc 0.84375 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1060\n",
368
      "Epoch   1 Batch   110 (  225) Loss 0.47173 Acc 0.79375 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1061\n",
369
      "Epoch   1 Batch   110 (  225) Loss 0.43876 Acc 0.83750 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1062\n",
370
      "Epoch   1 Batch   110 (  225) Loss 0.43396 Acc 0.80625 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1063\n",
371
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1063\n",
372
      "Round: 16\n",
373
      "Epoch   1 Batch   110 (  225) Loss 0.43497 Acc 0.84375 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1064\n",
374
      "Epoch   1 Batch   110 (  225) Loss 0.42715 Acc 0.83750 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1065\n",
375
      "Epoch   1 Batch   110 (  225) Loss 0.42176 Acc 0.82500 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1066\n",
376
      "Epoch   1 Batch   110 (  225) Loss 0.39219 Acc 0.84375 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1067\n",
377
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1064\n",
378
      "Round: 17\n",
379
      "Epoch   1 Batch   110 (  225) Loss 0.39438 Acc 0.85625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1068\n",
380
      "Epoch   1 Batch   110 (  225) Loss 0.43058 Acc 0.82500 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1069\n",
381
      "Epoch   1 Batch   110 (  225) Loss 0.40885 Acc 0.87500 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1070\n",
382
      "Epoch   1 Batch   110 (  225) Loss 0.39980 Acc 0.83125 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1071\n",
383
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1069\n",
384
      "Round: 18\n",
385
      "Epoch   1 Batch   110 (  225) Loss 0.40417 Acc 0.84375 | Val acc 0.86064 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1072\n",
386
      "Epoch   1 Batch   110 (  225) Loss 0.43128 Acc 0.83125 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1073\n",
387
      "Epoch   1 Batch   110 (  225) Loss 0.40664 Acc 0.86250 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1074\n",
388
      "Epoch   1 Batch   110 (  225) Loss 0.41514 Acc 0.86250 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1075\n",
389
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1072\n",
390
      "Round: 19\n",
391
      "Epoch   1 Batch   110 (  225) Loss 0.38532 Acc 0.85000 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1076\n",
392
      "Epoch   1 Batch   110 (  225) Loss 0.40468 Acc 0.85625 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1077\n",
393
      "Epoch   1 Batch   110 (  225) Loss 0.44360 Acc 0.82500 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1078\n",
394
      "Epoch   1 Batch   110 (  225) Loss 0.39450 Acc 0.85625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1079\n",
395
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1078\n",
396
      "Round: 20\n",
397
      "Epoch   1 Batch   110 (  225) Loss 0.37850 Acc 0.83750 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1080\n",
398
      "Epoch   1 Batch   110 (  225) Loss 0.37629 Acc 0.86250 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1081\n",
399
      "Epoch   1 Batch   110 (  225) Loss 0.38814 Acc 0.85625 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1082\n",
400
      "Epoch   1 Batch   110 (  225) Loss 0.37590 Acc 0.90000 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1083\n",
401
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1080\n",
402
      "Round: 21\n",
403
      "Epoch   1 Batch   110 (  225) Loss 0.39236 Acc 0.87500 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1084\n",
404
      "Epoch   1 Batch   110 (  225) Loss 0.38403 Acc 0.86875 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1085\n",
405
      "Epoch   1 Batch   110 (  225) Loss 0.39141 Acc 0.86875 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1086\n",
406
      "Epoch   1 Batch   110 (  225) Loss 0.37983 Acc 0.85625 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1087\n",
407
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1084\n",
408
      "Round: 22\n",
409
      "Epoch   1 Batch   110 (  225) Loss 0.40008 Acc 0.87500 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1088\n",
410
      "Epoch   1 Batch   110 (  225) Loss 0.34944 Acc 0.89375 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1089\n",
411
      "Epoch   1 Batch   110 (  225) Loss 0.39081 Acc 0.89375 | Val acc 0.83130 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1090\n",
412
      "Epoch   1 Batch   110 (  225) Loss 0.38600 Acc 0.84375 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1091\n",
413
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1089\n"
414
     ]
415
    },
416
    {
417
     "name": "stdout",
418
     "output_type": "stream",
419
     "text": [
420
      "Round: 23\n",
421
      "Epoch   1 Batch   110 (  225) Loss 0.39712 Acc 0.86250 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1092\n",
422
      "Epoch   1 Batch   110 (  225) Loss 0.35570 Acc 0.89375 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1093\n",
423
      "Epoch   1 Batch   110 (  225) Loss 0.35381 Acc 0.88750 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1094\n",
424
      "Epoch   1 Batch   110 (  225) Loss 0.37666 Acc 0.85625 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1095\n",
425
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1092\n",
426
      "Round: 24\n",
427
      "Epoch   1 Batch   110 (  225) Loss 0.38746 Acc 0.88750 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1096\n",
428
      "Epoch   1 Batch   110 (  225) Loss 0.36476 Acc 0.91250 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1097\n",
429
      "Epoch   1 Batch   110 (  225) Loss 0.37721 Acc 0.88750 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1098\n",
430
      "Epoch   1 Batch   110 (  225) Loss 0.35381 Acc 0.90625 | Val acc 0.83130 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1099\n",
431
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1096\n",
432
      "Round: 25\n",
433
      "Epoch   1 Batch   110 (  225) Loss 0.39612 Acc 0.85625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1100\n",
434
      "Epoch   1 Batch   110 (  225) Loss 0.38607 Acc 0.83750 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1101\n",
435
      "Epoch   1 Batch   110 (  225) Loss 0.37138 Acc 0.90000 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1102\n",
436
      "Epoch   1 Batch   110 (  225) Loss 0.36966 Acc 0.87500 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1103\n",
437
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1101\n",
438
      "Round: 26\n",
439
      "Epoch   1 Batch   110 (  225) Loss 0.33496 Acc 0.94375 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1104\n",
440
      "Epoch   1 Batch   110 (  225) Loss 0.35933 Acc 0.87500 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1105\n",
441
      "Epoch   1 Batch   110 (  225) Loss 0.35526 Acc 0.88750 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1106\n",
442
      "Epoch   1 Batch   110 (  225) Loss 0.36529 Acc 0.87500 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1107\n",
443
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1107\n",
444
      "Round: 27\n",
445
      "Epoch   1 Batch   110 (  225) Loss 0.35815 Acc 0.90625 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1108\n",
446
      "Epoch   1 Batch   110 (  225) Loss 0.35819 Acc 0.88125 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1109\n",
447
      "Epoch   1 Batch   110 (  225) Loss 0.36880 Acc 0.90000 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1110\n",
448
      "Epoch   1 Batch   110 (  225) Loss 0.35225 Acc 0.90625 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1111\n",
449
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1108\n",
450
      "Round: 28\n",
451
      "Epoch   1 Batch   110 (  225) Loss 0.36809 Acc 0.88750 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1112\n",
452
      "Epoch   1 Batch   110 (  225) Loss 0.36660 Acc 0.88125 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1113\n",
453
      "Epoch   1 Batch   110 (  225) Loss 0.33627 Acc 0.88125 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1114\n",
454
      "Epoch   1 Batch   110 (  225) Loss 0.35610 Acc 0.88750 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1115\n",
455
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1114\n",
456
      "Round: 29\n",
457
      "Epoch   1 Batch   110 (  225) Loss 0.35923 Acc 0.87500 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1116\n",
458
      "Epoch   1 Batch   110 (  225) Loss 0.34627 Acc 0.89375 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1117\n",
459
      "Epoch   1 Batch   110 (  225) Loss 0.36224 Acc 0.88750 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1118\n",
460
      "Epoch   1 Batch   110 (  225) Loss 0.36262 Acc 0.89375 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/GRU/model-gru, global_step 1119\n",
461
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1116\n"
462
     ]
463
    }
464
   ],
465
   "source": [
466
    "# Wrap the EMI pieces in a driver bound to graph g1, open its session, then train.\n",
467
    "with g1.as_default():\n",
468
    "    emiDriver = EMI_Driver(inputPipeline, emiGRU, emiTrainer)\n",
469
    "emiDriver.initializeSession(g1)\n",
470
    "y_updated, modelStats = emiDriver.run(\n",
471
    "    numClasses=NUM_OUTPUT,\n",
472
    "    x_train=x_train, y_train=y_train, bag_train=BAG_TRAIN,\n",
473
    "    x_val=x_val, y_val=y_val, bag_val=BAG_VAL,\n",
474
    "    numIter=NUM_ITER, keep_prob=KEEP_PROB, numRounds=NUM_ROUNDS,\n",
475
    "    batchSize=BATCH_SIZE, numEpochs=NUM_EPOCHS, modelPrefix=MODEL_PREFIX,\n",
476
    "    fracEMI=0.5, updatePolicy='top-k', k=1)"
477
   ]
478
  },
479
  {
480
   "cell_type": "markdown",
481
   "metadata": {},
482
   "source": [
483
    "# Evaluating the trained model"
484
   ]
485
  },
486
  {
487
   "cell_type": "code",
488
   "execution_count": 10,
489
   "metadata": {
490
    "ExecuteTime": {
491
     "end_time": "2018-12-14T14:35:15.218040Z",
492
     "start_time": "2018-12-14T14:35:15.211771Z"
493
    }
494
   },
495
   "outputs": [],
496
   "source": [
497
    "# Early prediction policy: commit to the first step whose top class probability\n",
498
    "#     reaches minProb; if no step does, fall back to the last step.\n",
499
    "def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
500
    "    \"\"\"Return (predicted class, step index) for a 2-D (step x class) score array.\"\"\"\n",
501
    "    assert instanceOut.ndim == 2\n",
502
    "    stepClasses = instanceOut.argmax(axis=1)\n",
503
    "    confidentSteps = np.flatnonzero(instanceOut.max(axis=1) >= minProb)\n",
504
    "    if confidentSteps.size > 0:\n",
505
    "        step = confidentSteps[0]\n",
506
    "    else:\n",
507
    "        step = len(instanceOut) - 1\n",
508
    "    return stepClasses[step], step\n",
509
    "\n",
510
    "def getEarlySaving(predictionStep, numTimeSteps, returnTotal=False):\n",
511
    "    \"\"\"Fraction of the per-instance numTimeSteps budget left unused.\n",
512
    "\n",
513
    "    Returns savings, or (savings, totalSteps) when returnTotal is true.\n",
514
    "    \"\"\"\n",
515
    "    # Policy step indices are 0-based (last step is len - 1), so index + 1 steps were consumed.\n",
516
    "    stepsUsed = np.reshape(predictionStep + 1, -1)\n",
517
    "    totalSteps = stepsUsed.sum()\n",
518
    "    savings = 1.0 - totalSteps / (stepsUsed.size * numTimeSteps)\n",
519
    "    return (savings, totalSteps) if returnTotal else savings"
520
   ]
521
  },
522
  {
523
   "cell_type": "code",
524
   "execution_count": 11,
525
   "metadata": {
526
    "ExecuteTime": {
527
     "end_time": "2018-12-14T14:35:16.257489Z",
528
     "start_time": "2018-12-14T14:35:15.221029Z"
529
    }
530
   },
531
   "outputs": [
532
    {
533
     "name": "stdout",
534
     "output_type": "stream",
535
     "text": [
536
      "Accuracy at k = 2: 0.842466\n",
537
      "Savings due to MI-RNN : 0.600000\n",
538
      "Savings due to Early prediction: 0.162573\n",
539
      "Total Savings: 0.665029\n"
540
     ]
541
    }
542
   ],
543
   "source": [
544
    "k = 2  # minSubsequenceLen used when pooling instance predictions into bag predictions\n",
545
    "predictions, predictionStep = emiDriver.getInstancePredictions(\n",
546
    "    x_test, y_test, earlyPolicy_minProb, minProb=0.99, keep_prob=1.0)\n",
547
    "bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
548
    "print('Accuracy at k = %d: %f' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))))\n",
549
    "mi_savings = 1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS  # part of the original window outside one sub-instance\n",
550
    "emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
551
    "total_savings = mi_savings + (1 - mi_savings) * emi_savings  # early stopping saves within what remains\n",
552
    "print('Savings due to MI-RNN : %f' % mi_savings)\n",
553
    "print('Savings due to Early prediction: %f' % emi_savings)\n",
554
    "print('Total Savings: %f' % (total_savings))"
555
   ]
556
  },
557
  {
558
   "cell_type": "code",
559
   "execution_count": 12,
560
   "metadata": {
561
    "ExecuteTime": {
562
     "end_time": "2018-12-14T14:35:17.044115Z",
563
     "start_time": "2018-12-14T14:35:16.259280Z"
564
    },
565
    "scrolled": true
566
   },
567
   "outputs": [
568
    {
569
     "name": "stdout",
570
     "output_type": "stream",
571
     "text": [
572
      "   len       acc  macro-fsc  macro-pre  macro-rec  micro-fsc  micro-pre  \\\n",
573
      "0    1  0.831703   0.831557   0.835146   0.832988   0.831703   0.831703   \n",
574
      "1    2  0.842466   0.842440   0.842466   0.842688   0.842466   0.842466   \n",
575
      "2    3  0.846380   0.846156   0.846630   0.845956   0.846380   0.846380   \n",
576
      "3    4  0.847358   0.846679   0.850044   0.846213   0.847358   0.847358   \n",
577
      "4    5  0.839530   0.838276   0.845458   0.837832   0.839530   0.839530   \n",
578
      "\n",
579
      "   micro-rec  fscore_01  \n",
580
      "0   0.831703   0.836502  \n",
581
      "1   0.842466   0.840436  \n",
582
      "2   0.846380   0.840285  \n",
583
      "3   0.847358   0.836478  \n",
584
      "4   0.839530   0.824034  \n",
585
      "Max accuracy 0.847358 at subsequencelength 4\n",
586
      "Max micro-f 0.847358 at subsequencelength 4\n",
587
      "Micro-precision 0.847358 at subsequencelength 4\n",
588
      "Micro-recall 0.847358 at subsequencelength 4\n",
589
      "Max macro-f 0.846679 at subsequencelength 4\n",
590
      "macro-precision 0.850044 at subsequencelength 4\n",
591
      "macro-recall 0.846213 at subsequencelength 4\n"
592
     ]
593
    }
594
   ],
595
   "source": [
596
    "# Break down accuracy / precision / recall / F-score for each bag subsequence length.\n",
597
    "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)"
598
   ]
599
  },
600
  {
601
   "cell_type": "markdown",
602
   "metadata": {},
603
   "source": [
604
    "## Picking the best model"
605
   ]
606
  },
607
  {
608
   "cell_type": "code",
609
   "execution_count": 18,
610
   "metadata": {
611
    "ExecuteTime": {
612
     "end_time": "2018-12-14T14:35:54.899340Z",
613
     "start_time": "2018-12-14T14:35:17.047464Z"
614
    },
615
    "scrolled": true
616
   },
617
   "outputs": [
618
    {
619
     "name": "stdout",
620
     "output_type": "stream",
621
     "text": [
622
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1003\n",
623
      "Round:  0, Validation accuracy: 0.6357, Test Accuracy (k = 4): 0.667319, Total Savings: 0.602926\n",
624
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1007\n",
625
      "Round:  1, Validation accuracy: 0.6993, Test Accuracy (k = 4): 0.693738, Total Savings: 0.603063\n",
626
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1010\n",
627
      "Round:  2, Validation accuracy: 0.7042, Test Accuracy (k = 4): 0.685910, Total Savings: 0.603875\n",
628
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1014\n",
629
      "Round:  3, Validation accuracy: 0.7311, Test Accuracy (k = 4): 0.693738, Total Savings: 0.603806\n",
630
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1019\n",
631
      "Round:  4, Validation accuracy: 0.7506, Test Accuracy (k = 4): 0.709393, Total Savings: 0.604521\n",
632
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1022\n",
633
      "Round:  5, Validation accuracy: 0.7677, Test Accuracy (k = 4): 0.726027, Total Savings: 0.605832\n",
634
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1026\n",
635
      "Round:  6, Validation accuracy: 0.7775, Test Accuracy (k = 4): 0.745597, Total Savings: 0.607730\n",
636
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1031\n",
637
      "Round:  7, Validation accuracy: 0.7848, Test Accuracy (k = 4): 0.772016, Total Savings: 0.609295\n",
638
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1035\n",
639
      "Round:  8, Validation accuracy: 0.8215, Test Accuracy (k = 4): 0.777886, Total Savings: 0.611252\n",
640
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1039\n",
641
      "Round:  9, Validation accuracy: 0.8264, Test Accuracy (k = 4): 0.798434, Total Savings: 0.614902\n",
642
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1043\n",
643
      "Round: 10, Validation accuracy: 0.8362, Test Accuracy (k = 4): 0.799413, Total Savings: 0.617515\n",
644
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1044\n",
645
      "Round: 11, Validation accuracy: 0.8460, Test Accuracy (k = 4): 0.814090, Total Savings: 0.618190\n",
646
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1048\n",
647
      "Round: 12, Validation accuracy: 0.8509, Test Accuracy (k = 4): 0.810176, Total Savings: 0.619305\n",
648
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1054\n",
649
      "Round: 13, Validation accuracy: 0.8509, Test Accuracy (k = 4): 0.825832, Total Savings: 0.620881\n",
650
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1057\n",
651
      "Round: 14, Validation accuracy: 0.8460, Test Accuracy (k = 4): 0.828767, Total Savings: 0.623914\n",
652
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1063\n",
653
      "Round: 15, Validation accuracy: 0.8533, Test Accuracy (k = 4): 0.823875, Total Savings: 0.634061\n",
654
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1064\n",
655
      "Round: 16, Validation accuracy: 0.8435, Test Accuracy (k = 4): 0.836595, Total Savings: 0.637857\n",
656
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1069\n",
657
      "Round: 17, Validation accuracy: 0.8509, Test Accuracy (k = 4): 0.828767, Total Savings: 0.641301\n",
658
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1072\n",
659
      "Round: 18, Validation accuracy: 0.8606, Test Accuracy (k = 4): 0.831703, Total Savings: 0.640753\n",
660
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1078\n",
661
      "Round: 19, Validation accuracy: 0.8509, Test Accuracy (k = 4): 0.832681, Total Savings: 0.646027\n",
662
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1080\n",
663
      "Round: 20, Validation accuracy: 0.8533, Test Accuracy (k = 4): 0.836595, Total Savings: 0.648141\n",
664
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1084\n",
665
      "Round: 21, Validation accuracy: 0.8533, Test Accuracy (k = 4): 0.846380, Total Savings: 0.648679\n",
666
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1089\n",
667
      "Round: 22, Validation accuracy: 0.8509, Test Accuracy (k = 4): 0.839530, Total Savings: 0.651624\n",
668
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1092\n",
669
      "Round: 23, Validation accuracy: 0.8533, Test Accuracy (k = 4): 0.835616, Total Savings: 0.653933\n",
670
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1096\n",
671
      "Round: 24, Validation accuracy: 0.8557, Test Accuracy (k = 4): 0.839530, Total Savings: 0.655998\n",
672
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1101\n",
673
      "Round: 25, Validation accuracy: 0.8509, Test Accuracy (k = 4): 0.833659, Total Savings: 0.657808\n",
674
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1107\n",
675
      "Round: 26, Validation accuracy: 0.8533, Test Accuracy (k = 4): 0.840509, Total Savings: 0.661654\n",
676
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1108\n",
677
      "Round: 27, Validation accuracy: 0.8533, Test Accuracy (k = 4): 0.841487, Total Savings: 0.663053\n",
678
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1114\n",
679
      "Round: 28, Validation accuracy: 0.8557, Test Accuracy (k = 4): 0.839530, Total Savings: 0.663738\n",
680
      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/GRU/model-gru-1116\n",
681
      "Round: 29, Validation accuracy: 0.8509, Test Accuracy (k = 4): 0.842466, Total Savings: 0.665029\n"
682
     ]
683
    }
684
   ],
685
   "source": [
686
    "# Re-evaluate every checkpoint in modelStats on the test set; the values left by the last iteration feed later cells.\n",
687
    "mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)  # loop-invariant\n",
688
    "with open(os.devnull, 'w') as devnull:  # redirFile target: opened writable (was 'r') and closed on exit\n",
689
    "    for val in modelStats:\n",
690
    "        round_, acc, modelPrefix, globalStep = val\n",
691
    "        emiDriver.loadSavedGraphToNewSession(modelPrefix, globalStep, redirFile=devnull)\n",
692
    "        predictions, predictionStep = emiDriver.getInstancePredictions(\n",
693
    "            x_test, y_test, earlyPolicy_minProb, minProb=0.99, keep_prob=1.0)\n",
694
    "        bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
695
    "        print(\"Round: %2d, Validation accuracy: %.4f\" % (round_, acc), end='')\n",
696
    "        print(', Test Accuracy (k = %d): %f, ' % (k,  np.mean((bagPredictions == BAG_TEST).astype(int))), end='')\n",
697
    "        emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
698
    "        total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
699
    "        print(\"Total Savings: %f\" % total_savings)"
700
   ]
701
  },
702
  {
703
   "cell_type": "code",
704
   "execution_count": 14,
705
   "metadata": {},
706
   "outputs": [],
707
   "source": [
708
    "params = {  # named values are read from this run's variables; bare literals are hand-copied - TODO confirm\n",
709
    "    \"NUM_HIDDEN\" : 128,\n",
710
    "    \"NUM_TIMESTEPS\" : NUM_TIMESTEPS, #subinstance length.\n",
711
    "    \"ORIGINAL_NUM_TIMESTEPS\" : ORIGINAL_NUM_TIMESTEPS,\n",
712
    "    \"NUM_FEATS\" : 16,\n",
713
    "    \"FORGET_BIAS\" : 1.0,\n",
714
    "    \"NUM_OUTPUT\" : NUM_OUTPUT,\n",
715
    "    \"USE_DROPOUT\" : 1, # '1' -> True. '0' -> False\n",
716
    "    \"KEEP_PROB\" : KEEP_PROB,\n",
717
    "    \"PREFETCH_NUM\" : 5,\n",
718
    "    \"BATCH_SIZE\" : BATCH_SIZE,\n",
719
    "    \"NUM_EPOCHS\" : NUM_EPOCHS,\n",
720
    "    \"NUM_ITER\" : NUM_ITER,\n",
721
    "    \"NUM_ROUNDS\" : NUM_ROUNDS,\n",
722
    "    \"LEARNING_RATE\" : 0.001,\n",
723
    "    \"MODEL_PREFIX\" : MODEL_PREFIX\n",
724
    "}"
725
   ]
726
  },
727
  {
728
   "cell_type": "code",
729
   "execution_count": 15,
730
   "metadata": {},
731
   "outputs": [
732
    {
733
     "name": "stdout",
734
     "output_type": "stream",
735
     "text": [
736
      "   len       acc  macro-fsc  macro-pre  macro-rec  micro-fsc  micro-pre  \\\n",
737
      "0    1  0.831703   0.831557   0.835146   0.832988   0.831703   0.831703   \n",
738
      "1    2  0.842466   0.842440   0.842466   0.842688   0.842466   0.842466   \n",
739
      "2    3  0.846380   0.846156   0.846630   0.845956   0.846380   0.846380   \n",
740
      "3    4  0.847358   0.846679   0.850044   0.846213   0.847358   0.847358   \n",
741
      "4    5  0.839530   0.838276   0.845458   0.837832   0.839530   0.839530   \n",
742
      "\n",
743
      "   micro-rec  fscore_01  \n",
744
      "0   0.831703   0.836502  \n",
745
      "1   0.842466   0.840436  \n",
746
      "2   0.846380   0.840285  \n",
747
      "3   0.847358   0.836478  \n",
748
      "4   0.839530   0.824034  \n",
749
      "Max accuracy 0.847358 at subsequencelength 4\n",
750
      "Max micro-f 0.847358 at subsequencelength 4\n",
751
      "Micro-precision 0.847358 at subsequencelength 4\n",
752
      "Micro-recall 0.847358 at subsequencelength 4\n",
753
      "Max macro-f 0.846679 at subsequencelength 4\n",
754
      "macro-precision 0.850044 at subsequencelength 4\n",
755
      "macro-recall 0.846213 at subsequencelength 4\n",
756
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
757
      "|    |   len |      acc |   macro-fsc |   macro-pre |   macro-rec |   micro-fsc |   micro-pre |   micro-rec |   fscore_01 |\n",
758
      "+====+=======+==========+=============+=============+=============+=============+=============+=============+=============+\n",
759
      "|  0 |     1 | 0.831703 |    0.831557 |    0.835146 |    0.832988 |    0.831703 |    0.831703 |    0.831703 |    0.836502 |\n",
760
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
761
      "|  1 |     2 | 0.842466 |    0.84244  |    0.842466 |    0.842688 |    0.842466 |    0.842466 |    0.842466 |    0.840436 |\n",
762
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
763
      "|  2 |     3 | 0.84638  |    0.846156 |    0.84663  |    0.845956 |    0.84638  |    0.84638  |    0.84638  |    0.840285 |\n",
764
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
765
      "|  3 |     4 | 0.847358 |    0.846679 |    0.850044 |    0.846213 |    0.847358 |    0.847358 |    0.847358 |    0.836478 |\n",
766
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
767
      "|  4 |     5 | 0.83953  |    0.838276 |    0.845458 |    0.837832 |    0.83953  |    0.83953  |    0.83953  |    0.824034 |\n",
768
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n"
769
     ]
770
    }
771
   ],
772
   "source": [
773
    "# Hyperparameters plus the results left by the most recent evaluation.\n",
774
    "gru_dict = {\n",
775
    "    **params,\n",
776
    "    \"k\": k,\n",
777
    "    \"accuracy\": np.mean((bagPredictions == BAG_TEST).astype(int)),\n",
778
    "    \"total_savings\": total_savings,\n",
779
    "    \"y_test\": BAG_TEST,\n",
780
    "    \"y_pred\": bagPredictions}\n",
781
    "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)\n",
782
    "print (tabulate(df, headers=list(df.columns), tablefmt='grid'))"
783
   ]
784
  },
785
  {
786
   "cell_type": "code",
787
   "execution_count": 16,
788
   "metadata": {},
789
   "outputs": [
790
    {
791
     "name": "stdout",
792
     "output_type": "stream",
793
     "text": [
794
      "Results for this run have been saved at /home/sf/data/SWELL/GRU/ .\n"
795
     ]
796
    }
797
   ],
798
   "source": [
799
    "dirname = \"/home/sf/data/SWELL/GRU/\"\n",
800
    "pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)\n",
801
    "# Timestamped name such as 2019-9-3|15-1 (NOTE(review): '|' is not a valid filename character on Windows).\n",
802
    "now = datetime.datetime.now()\n",
803
    "filename = f\"{now.year}-{now.month}-{now.day}|{now.hour}-{now.minute}\"\n",
804
    "\n",
805
    "# Save the params + results; the with-block closes (and flushes) the file before success is reported.\n",
806
    "with open(dirname + filename + \".pkl\", mode='wb') as results_file:\n",
807
    "    pkl.dump(gru_dict, results_file)\n",
808
    "print (\"Results for this run have been saved at\" , dirname, \".\")"
809
   ]
810
  },
811
  {
812
   "cell_type": "code",
813
   "execution_count": 17,
814
   "metadata": {},
815
   "outputs": [
816
    {
817
     "data": {
818
      "text/plain": [
819
       "'/home/sf/data/SWELL/GRU/2019-9-3|15-1.pkl'"
820
      ]
821
     },
822
     "execution_count": 17,
823
     "metadata": {},
824
     "output_type": "execute_result"
825
    }
826
   ],
827
   "source": [
828
    "f\"{dirname}{filename}.pkl\"  # full path of the results pickle"
829
   ]
830
  }
831
 ],
832
 "metadata": {
833
  "kernelspec": {
834
   "display_name": "Python 3",
835
   "language": "python",
836
   "name": "python3"
837
  },
838
  "language_info": {
839
   "codemirror_mode": {
840
    "name": "ipython",
841
    "version": 3
842
   },
843
   "file_extension": ".py",
844
   "mimetype": "text/x-python",
845
   "name": "python",
846
   "nbconvert_exporter": "python",
847
   "pygments_lexer": "ipython3",
848
   "version": "3.7.3"
849
  },
850
  "latex_envs": {
851
   "LaTeX_envs_menu_present": true,
852
   "autoclose": false,
853
   "autocomplete": true,
854
   "bibliofile": "biblio.bib",
855
   "cite_by": "apalike",
856
   "current_citInitial": 1,
857
   "eqLabelWithNumbers": true,
858
   "eqNumInitial": 1,
859
   "hotkeys": {
860
    "equation": "Ctrl-E",
861
    "itemize": "Ctrl-I"
862
   },
863
   "labels_anchors": false,
864
   "latex_user_defs": false,
865
   "report_style_numbering": false,
866
   "user_envs_cfg": false
867
  }
868
 },
869
 "nbformat": 4,
870
 "nbformat_minor": 2
871
}