
--- a
+++ b/SWELL-KW/SWELL-KW_LSTM.ipynb
@@ -0,0 +1,1493 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# SWELL-KW LSTM"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Adapted from Microsoft's notebooks, available at https://github.com/microsoft/EdgeML authored by Dennis et al."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-07-15T18:30:17.522073Z",
+     "start_time": "2019-07-15T18:30:17.217355Z"
+    },
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "from tabulate import tabulate\n",
+    "import os\n",
+    "import datetime as datetime\n",
+    "import pickle as pkl\n",
+    "import pathlib"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## DataFrames from CSVs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 224,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-07-15T18:30:28.212447Z",
+     "start_time": "2019-07-15T18:30:17.889941Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "df = pd.read_csv('/home/sf/data/SWELL-KW/EDA_New.csv', index_col=0)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Preprocessing "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 226,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df['condition'].replace('time pressure', 'interruption', inplace=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 227,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df['condition'].replace('interruption', 1, inplace=True)\n",
+    "df['condition'].replace('no stress', 0, inplace=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Split Ground Truth "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 188,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0.027338218248412812"
+      ]
+     },
+     "execution_count": 188,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "df[df['condition']==1]['RANGE'].mean()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 229,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-07-15T18:30:37.760226Z",
+     "start_time": "2019-07-15T18:30:37.597908Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "filtered_train = df.drop(['condition'], axis=1)\n",
+    "filtered_target = df['condition']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 231,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-07-15T18:32:27.581110Z",
+     "start_time": "2019-07-15T18:32:27.576162Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(104040,)\n",
+      "(104040, 22)\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(filtered_target.shape)\n",
+    "print(filtered_train.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 232,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0    52695\n",
+       "1    51345\n",
+       "Name: condition, dtype: int64"
+      ]
+     },
+     "execution_count": 232,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pd.value_counts(filtered_target)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 233,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:24:41.449331Z",
+     "start_time": "2019-06-19T05:24:41.445741Z"
+    },
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "y = filtered_target.values.reshape(-1, 20)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Convert to 3D - (Bags, Timesteps, Features)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 234,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "22"
+      ]
+     },
+     "execution_count": 234,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(filtered_train.columns)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 235,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:24:43.963421Z",
+     "start_time": "2019-06-19T05:24:43.944207Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(104040, 22)\n",
+      "(5202, 20, 22)\n"
+     ]
+    }
+   ],
+   "source": [
+    "x = filtered_train.values\n",
+    "print(x.shape)\n",
+    "x = x.reshape(int(len(x) / 20), 20, 22)\n",
+    "print(x.shape)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Filter Overlapping Bags"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 236,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:25:19.299273Z",
+     "start_time": "2019-06-19T05:25:18.952971Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[120, 198, 289, 408, 410, 415, 544, 667, 776, 894, 908, 911, 913, 1019, 1103, 1117, 1147, 1150, 1191, 1305, 1440, 1549, 1643, 1660, 1685, 1805, 1937, 2042, 2175, 2179, 2183, 2184, 2384, 2490, 2566, 2571, 2594, 2597, 2637, 2638, 2761, 2810, 2918, 2990, 3004, 3008, 3015, 3056, 3164, 3236, 3244, 3274, 3277, 3318, 3433, 3546, 3556, 3560, 3562, 3563, 3564, 3676, 3786, 3796, 3800, 3839, 3840, 3841, 3956, 4051, 4081, 4084, 4125, 4126, 4232, 4310, 4411, 4527, 4544, 4546, 4549, 4554, 4559, 4564, 4687, 4789, 4895, 4968, 4985, 4989, 5008, 5131]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# filtering bags that overlap with another class\n",
+    "bags_to_remove = []\n",
+    "for i in range(len(y)):\n",
+    "    if len(set(y[i])) > 1:\n",
+    "        bags_to_remove.append(i)\n",
+    "print(bags_to_remove)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 237,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:25:26.446153Z",
+     "start_time": "2019-06-19T05:25:26.256245Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "x = np.delete(x, bags_to_remove, axis=0)\n",
+    "y = np.delete(y, bags_to_remove, axis=0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 238,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:25:27.260726Z",
+     "start_time": "2019-06-19T05:25:27.254474Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(5110, 20, 22)"
+      ]
+     },
+     "execution_count": 238,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x.shape "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 239,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:25:28.101475Z",
+     "start_time": "2019-06-19T05:25:28.096009Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(5110, 20)"
+      ]
+     },
+     "execution_count": 239,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "y.shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Categorical Representation "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 240,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:25:50.094089Z",
+     "start_time": "2019-06-19T05:25:49.746284Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "one_hot_list = []\n",
+    "for i in range(len(y)):\n",
+    "    one_hot_list.append(set(y[i]).pop())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 241,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:25:52.203633Z",
+     "start_time": "2019-06-19T05:25:52.198467Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "categorical_y_ver = one_hot_list\n",
+    "categorical_y_ver = np.array(categorical_y_ver)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 242,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:25:53.132006Z",
+     "start_time": "2019-06-19T05:25:53.126314Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(5110,)"
+      ]
+     },
+     "execution_count": 242,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "categorical_y_ver.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 243,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:25:54.495163Z",
+     "start_time": "2019-06-19T05:25:54.489349Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "20"
+      ]
+     },
+     "execution_count": 243,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x.shape[1]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 244,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:26:08.021392Z",
+     "start_time": "2019-06-19T05:26:08.017038Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def one_hot(y, numOutput):\n",
+    "    y = np.reshape(y, [-1])\n",
+    "    ret = np.zeros([y.shape[0], numOutput])\n",
+    "    for i, label in enumerate(y):\n",
+    "        ret[i, label] = 1\n",
+    "    return ret"
+   ]
+  },
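+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A minimal usage sketch for `one_hot` with the two classes used here: integer labels in {0, 1} become rows of a one-hot matrix."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: one_hot maps integer class labels to one-hot rows, e.g.\n",
+    "# one_hot(np.array([0, 1, 1]), 2) -> [[1., 0.], [0., 1.], [0., 1.]]\n",
+    "one_hot(np.array([0, 1, 1]), 2)"
+   ]
+  },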
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Extract 3D Normalized Data with Validation Set"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 245,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:26:08.931435Z",
+     "start_time": "2019-06-19T05:26:08.927397Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from sklearn.model_selection import train_test_split\n",
+    "import pathlib"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 246,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:26:10.295822Z",
+     "start_time": "2019-06-19T05:26:09.832723Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "x_train_val_combined, x_test, y_train_val_combined, y_test = train_test_split(x, categorical_y_ver, test_size=0.20, random_state=42)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 247,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:26:11.260084Z",
+     "start_time": "2019-06-19T05:26:11.253714Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([1, 1, 0, ..., 1, 0, 1])"
+      ]
+     },
+     "execution_count": 247,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "y_test"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 248,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-18T17:24:41.832337Z",
+     "start_time": "2019-06-18T17:24:41.141678Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "extractedDir = '/home/sf/data/SWELL-KW/'\n",
+    "# def generateData(extractedDir):\n",
+    "# x_train_val_combined, x_test, y_train_val_combined, y_test = readData(extractedDir)\n",
+    "timesteps = x_train_val_combined.shape[-2] #128/256\n",
+    "feats = x_train_val_combined.shape[-1]  #16\n",
+    "\n",
+    "trainSize = int(x_train_val_combined.shape[0]*0.9) #6566\n",
+    "x_train, x_val = x_train_val_combined[:trainSize], x_train_val_combined[trainSize:] \n",
+    "y_train, y_val = y_train_val_combined[:trainSize], y_train_val_combined[trainSize:]\n",
+    "\n",
+    "# normalization\n",
+    "x_train = np.reshape(x_train, [-1, feats])\n",
+    "mean = np.mean(x_train, axis=0)\n",
+    "std = np.std(x_train, axis=0)\n",
+    "\n",
+    "# normalize train\n",
+    "x_train = x_train - mean\n",
+    "x_train = x_train / std\n",
+    "x_train = np.reshape(x_train, [-1, timesteps, feats])\n",
+    "\n",
+    "# normalize val\n",
+    "x_val = np.reshape(x_val, [-1, feats])\n",
+    "x_val = x_val - mean\n",
+    "x_val = x_val / std\n",
+    "x_val = np.reshape(x_val, [-1, timesteps, feats])\n",
+    "\n",
+    "# normalize test\n",
+    "x_test = np.reshape(x_test, [-1, feats])\n",
+    "x_test = x_test - mean\n",
+    "x_test = x_test / std\n",
+    "x_test = np.reshape(x_test, [-1, timesteps, feats])\n",
+    "\n",
+    "# shuffle test, as this was remaining\n",
+    "idx = np.arange(len(x_test))\n",
+    "np.random.shuffle(idx)\n",
+    "x_test = x_test[idx]\n",
+    "y_test = y_test[idx]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 249,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-18T17:25:50.674962Z",
+     "start_time": "2019-06-18T17:25:50.481068Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/home/sf/data/SWELL-KW//\n"
+     ]
+    }
+   ],
+   "source": [
+    "# one-hot encoding of labels\n",
+    "numOutput = 2\n",
+    "y_train = one_hot(y_train, numOutput)\n",
+    "y_val = one_hot(y_val, numOutput)\n",
+    "y_test = one_hot(y_test, numOutput)\n",
+    "extractedDir += '/'\n",
+    "\n",
+    "pathlib.Path(extractedDir + 'RAW').mkdir(parents=True, exist_ok = True)\n",
+    "\n",
+    "np.save(extractedDir + \"RAW/x_train\", x_train)\n",
+    "np.save(extractedDir + \"RAW/y_train\", y_train)\n",
+    "np.save(extractedDir + \"RAW/x_test\", x_test)\n",
+    "np.save(extractedDir + \"RAW/y_test\", y_test)\n",
+    "np.save(extractedDir + \"RAW/x_val\", x_val)\n",
+    "np.save(extractedDir + \"RAW/y_val\", y_val)\n",
+    "\n",
+    "print(extractedDir)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 250,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-18T17:26:35.381650Z",
+     "start_time": "2019-06-18T17:26:35.136645Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "x_test.npy  x_train.npy  x_val.npy  y_test.npy  y_train.npy  y_val.npy\r\n"
+     ]
+    }
+   ],
+   "source": [
+    "ls /home/sf/data/SWELL-KW/RAW"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 251,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-18T17:27:34.458130Z",
+     "start_time": "2019-06-18T17:27:34.323712Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(3679, 20, 22)"
+      ]
+     },
+     "execution_count": 251,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "np.load('//home/sf/data/SWELL-KW/RAW/x_train.npy').shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Make 4D EMI Data (Bags, Subinstances, Subinstance Length, Features)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:30:53.720179Z",
+     "start_time": "2019-06-19T05:30:53.713756Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def loadData(dirname):\n",
+    "    x_train = np.load(dirname + '/' + 'x_train.npy')\n",
+    "    y_train = np.load(dirname + '/' + 'y_train.npy')\n",
+    "    x_test = np.load(dirname + '/' + 'x_test.npy')\n",
+    "    y_test = np.load(dirname + '/' + 'y_test.npy')\n",
+    "    x_val = np.load(dirname + '/' + 'x_val.npy')\n",
+    "    y_val = np.load(dirname + '/' + 'y_val.npy')\n",
+    "    return x_train, y_train, x_test, y_test, x_val, y_val"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-18T17:29:23.614052Z",
+     "start_time": "2019-06-18T17:29:23.601324Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def bagData(X, Y, subinstanceLen, subinstanceStride):\n",
+    "    numClass = 2\n",
+    "    numSteps = 20\n",
+    "    numFeats = 22\n",
+    "    assert X.ndim == 3\n",
+    "    assert X.shape[1] == numSteps\n",
+    "    assert X.shape[2] == numFeats\n",
+    "    assert subinstanceLen <= numSteps\n",
+    "    assert subinstanceLen > 0\n",
+    "    assert subinstanceStride <= numSteps\n",
+    "    assert subinstanceStride >= 0\n",
+    "    assert len(X) == len(Y)\n",
+    "    assert Y.ndim == 2\n",
+    "    assert Y.shape[1] == numClass\n",
+    "    x_bagged = []\n",
+    "    y_bagged = []\n",
+    "    for i, point in enumerate(X[:, :, :]):\n",
+    "        instanceList = []\n",
+    "        start = 0\n",
+    "        end = subinstanceLen\n",
+    "        while True:\n",
+    "            x = point[start:end, :]\n",
+    "            if len(x) < subinstanceLen:\n",
+    "                x_ = np.zeros([subinstanceLen, x.shape[1]])\n",
+    "                x_[:len(x), :] = x[:, :]\n",
+    "                x = x_\n",
+    "            instanceList.append(x)\n",
+    "            if end >= numSteps:\n",
+    "                break\n",
+    "            start += subinstanceStride\n",
+    "            end += subinstanceStride\n",
+    "        bag = np.array(instanceList)\n",
+    "        numSubinstance = bag.shape[0]\n",
+    "        label = Y[i]\n",
+    "        label = np.argmax(label)\n",
+    "        labelBag = np.zeros([numSubinstance, numClass])\n",
+    "        labelBag[:, label] = 1\n",
+    "        x_bagged.append(bag)\n",
+    "        label = np.array(labelBag)\n",
+    "        y_bagged.append(label)\n",
+    "    return np.array(x_bagged), np.array(y_bagged)"
+   ]
+  },
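+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A minimal sanity check of the bagging logic (a sketch, assuming the settings used below: `numSteps = 20`, `subinstanceLen = 8`, `subinstanceStride = 3`): the sliding window starts at offsets 0, 3, 6, 9 and 12, so every bag yields 5 subinstances and the bagged data has shape `(numBags, 5, 8, 22)`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: enumerate the subinstance windows that bagData produces for\n",
+    "# numSteps=20, subinstanceLen=8, subinstanceStride=3 (the settings used below).\n",
+    "numSteps, subLen, subStride = 20, 8, 3\n",
+    "start, end, windows = 0, subLen, []\n",
+    "while True:\n",
+    "    windows.append((start, min(end, numSteps)))\n",
+    "    if end >= numSteps:\n",
+    "        break\n",
+    "    start += subStride\n",
+    "    end += subStride\n",
+    "print(windows)  # [(0, 8), (3, 11), (6, 14), (9, 17), (12, 20)] -> 5 subinstances per bag"
+   ]
+  },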
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:33:06.531994Z",
+     "start_time": "2019-06-19T05:33:06.523884Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def makeEMIData(subinstanceLen, subinstanceStride, sourceDir, outDir):\n",
+    "    x_train, y_train, x_test, y_test, x_val, y_val = loadData(sourceDir)\n",
+    "    x, y = bagData(x_train, y_train, subinstanceLen, subinstanceStride)\n",
+    "    np.save(outDir + '/x_train.npy', x)\n",
+    "    np.save(outDir + '/y_train.npy', y)\n",
+    "    print('Num train %d' % len(x))\n",
+    "    x, y = bagData(x_test, y_test, subinstanceLen, subinstanceStride)\n",
+    "    np.save(outDir + '/x_test.npy', x)\n",
+    "    np.save(outDir + '/y_test.npy', y)\n",
+    "    print('Num test %d' % len(x))\n",
+    "    x, y = bagData(x_val, y_val, subinstanceLen, subinstanceStride)\n",
+    "    np.save(outDir + '/x_val.npy', x)\n",
+    "    np.save(outDir + '/y_val.npy', y)\n",
+    "    print('Num val %d' % len(x))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:35:20.960014Z",
+     "start_time": "2019-06-19T05:35:20.050363Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Num train 3679\n",
+      "Num test 1022\n",
+      "Num val 409\n"
+     ]
+    }
+   ],
+   "source": [
+    "subinstanceLen = 8\n",
+    "subinstanceStride = 3\n",
+    "extractedDir = '/home/sf/data/SWELL-KW'\n",
+    "from os import mkdir\n",
+    "#mkdir('/home/sf/data/SWELL-KW/8_3')\n",
+    "rawDir = extractedDir + '/RAW'\n",
+    "sourceDir = rawDir\n",
+    "outDir = extractedDir + '/%d_%d/' % (subinstanceLen, subinstanceStride)\n",
+    "makeEMIData(subinstanceLen, subinstanceStride, sourceDir, outDir)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-18T17:48:58.293843Z",
+     "start_time": "2019-06-18T17:48:58.285383Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(3679, 5, 2)"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "np.load('/home/sf/data/SWELL-KW/8_3/y_train.npy').shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:35:48.609552Z",
+     "start_time": "2019-06-19T05:35:48.604291Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from edgeml.graph.rnn import EMI_DataPipeline\n",
+    "from edgeml.graph.rnn import EMI_BasicLSTM, EMI_FastGRNN, EMI_FastRNN, EMI_GRU\n",
+    "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
+    "import edgeml.utils"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:43:48.032413Z",
+     "start_time": "2019-06-19T05:43:47.987839Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def lstm_experiment_generator(params, path = './DSAAR/64_16/'):\n",
+    "    \"\"\"\n",
+    "        Function that will generate the experiments to be run.\n",
+    "        Inputs : \n",
+    "        (1) Dictionary params, to set the network parameters.\n",
+    "        (2) Name of the Model to be run from [EMI-LSTM, EMI-FastGRNN, EMI-GRU]\n",
+    "        (3) Path to the dataset, where the csv files are present.\n",
+    "    \"\"\"\n",
+    "    \n",
+    "    #Copy the contents of the params dictionary.\n",
+    "    lstm_dict = {**params}\n",
+    "    \n",
+    "    #---------------------------PARAM SETTING----------------------#\n",
+    "    \n",
+    "    # Network parameters for our LSTM + FC Layer\n",
+    "    NUM_HIDDEN = params[\"NUM_HIDDEN\"]\n",
+    "    NUM_TIMESTEPS = params[\"NUM_TIMESTEPS\"]\n",
+    "    ORIGINAL_NUM_TIMESTEPS = params[\"ORIGINAL_NUM_TIMESTEPS\"]\n",
+    "    NUM_FEATS = params[\"NUM_FEATS\"]\n",
+    "    FORGET_BIAS = params[\"FORGET_BIAS\"]\n",
+    "    NUM_OUTPUT = params[\"NUM_OUTPUT\"]\n",
+    "    USE_DROPOUT = True if (params[\"USE_DROPOUT\"] == 1) else False\n",
+    "    KEEP_PROB = params[\"KEEP_PROB\"]\n",
+    "\n",
+    "    # For dataset API\n",
+    "    PREFETCH_NUM = params[\"PREFETCH_NUM\"]\n",
+    "    BATCH_SIZE = params[\"BATCH_SIZE\"]\n",
+    "\n",
+    "    # Number of epochs in *one iteration*\n",
+    "    NUM_EPOCHS = params[\"NUM_EPOCHS\"]\n",
+    "    # Number of iterations in *one round*. After each iteration,\n",
+    "    # the model is dumped to disk. At the end of the current\n",
+    "    # round, the best model among all the dumped models in the\n",
+    "    # current round is picked up..\n",
+    "    NUM_ITER = params[\"NUM_ITER\"]\n",
+    "    # A round consists of multiple training iterations and a belief\n",
+    "    # update step using the best model from all of these iterations\n",
+    "    NUM_ROUNDS = params[\"NUM_ROUNDS\"]\n",
+    "    LEARNING_RATE = params[\"LEARNING_RATE\"]\n",
+    "\n",
+    "    # A staging direcory to store models\n",
+    "    MODEL_PREFIX = params[\"MODEL_PREFIX\"]\n",
+    "    \n",
+    "    #----------------------END OF PARAM SETTING----------------------#\n",
+    "    \n",
+    "    #----------------------DATA LOADING------------------------------#\n",
+    "    \n",
+    "    x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n",
+    "    x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n",
+    "    x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n",
+    "\n",
+    "    # BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
+    "    # step of EMI/MI RNN\n",
+    "    BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
+    "    BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
+    "    BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
+    "    NUM_SUBINSTANCE = x_train.shape[1]\n",
+    "    print(\"x_train shape is:\", x_train.shape)\n",
+    "    print(\"y_train shape is:\", y_train.shape)\n",
+    "    print(\"x_test shape is:\", x_val.shape)\n",
+    "    print(\"y_test shape is:\", y_val.shape)\n",
+    "    \n",
+    "    #----------------------END OF DATA LOADING------------------------------#    \n",
+    "    \n",
+    "    #----------------------COMPUTATION GRAPH--------------------------------#\n",
+    "    \n",
+    "    # Define the linear secondary classifier\n",
+    "    def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
+    "        W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
+    "        B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
+    "        y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
+    "        self.output = y_cap\n",
+    "        self.graphCreated = True\n",
+    "\n",
+    "    def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
+    "        y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
+    "        self.output = y_cap\n",
+    "        self.graphCreated = True\n",
+    "\n",
+    "    def feedDictFunc(self, keep_prob=None, inference=False, **kwargs):\n",
+    "        if inference is False:\n",
+    "            feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
+    "        else:\n",
+    "            feedDict = {self._emiGraph.keep_prob: 1.0}\n",
+    "        return feedDict\n",
+    "\n",
+    "    EMI_BasicLSTM._createExtendedGraph = createExtendedGraph\n",
+    "    EMI_BasicLSTM._restoreExtendedGraph = restoreExtendedGraph\n",
+    "\n",
+    "    if USE_DROPOUT is True:\n",
+    "        EMI_Driver.feedDictFunc = feedDictFunc\n",
+    "    \n",
+    "    inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
+    "    emiLSTM = EMI_BasicLSTM(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS,\n",
+    "                            forgetBias=FORGET_BIAS, useDropout=USE_DROPOUT)\n",
+    "    emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy',\n",
+    "                             stepSize=LEARNING_RATE)\n",
+    "    \n",
+    "    tf.reset_default_graph()\n",
+    "    g1 = tf.Graph()    \n",
+    "    with g1.as_default():\n",
+    "        # Obtain the iterators to each batch of the data\n",
+    "        x_batch, y_batch = inputPipeline()\n",
+    "        # Create the forward computation graph based on the iterators\n",
+    "        y_cap = emiLSTM(x_batch)\n",
+    "        # Create loss graphs and training routines\n",
+    "        emiTrainer(y_cap, y_batch)\n",
+    "        \n",
+    "    #------------------------------END OF COMPUTATION GRAPH------------------------------#\n",
+    "    \n",
+    "    #-------------------------------------EMI DRIVER-------------------------------------#\n",
+    "        \n",
+    "    with g1.as_default():\n",
+    "        emiDriver = EMI_Driver(inputPipeline, emiLSTM, emiTrainer)\n",
+    "\n",
+    "    emiDriver.initializeSession(g1)\n",
+    "    y_updated, modelStats = emiDriver.run(numClasses=NUM_OUTPUT, x_train=x_train,\n",
+    "                                          y_train=y_train, bag_train=BAG_TRAIN,\n",
+    "                                          x_val=x_val, y_val=y_val, bag_val=BAG_VAL,\n",
+    "                                          numIter=NUM_ITER, keep_prob=KEEP_PROB,\n",
+    "                                          numRounds=NUM_ROUNDS, batchSize=BATCH_SIZE,\n",
+    "                                          numEpochs=NUM_EPOCHS, modelPrefix=MODEL_PREFIX,\n",
+    "                                          fracEMI=0.5, updatePolicy='top-k', k=1)\n",
+    "    \n",
+    "    #-------------------------------END OF EMI DRIVER-------------------------------------#\n",
+    "    \n",
+    "    #-----------------------------------EARLY SAVINGS-------------------------------------#\n",
+    "    \n",
+    "    \"\"\"\n",
+    "        Early Prediction Policy: We make an early prediction based on the predicted classes\n",
+    "        probability. If the predicted class probability > minProb at some step, we make\n",
+    "        a prediction at that step.\n",
+    "    \"\"\"\n",
+    "    def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
+    "        assert instanceOut.ndim == 2\n",
+    "        classes = np.argmax(instanceOut, axis=1)\n",
+    "        prob = np.max(instanceOut, axis=1)\n",
+    "        index = np.where(prob >= minProb)[0]\n",
+    "        if len(index) == 0:\n",
+    "            assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
+    "            return classes[-1], len(instanceOut) - 1\n",
+    "        index = index[0]\n",
+    "        return classes[index], index\n",
+    "\n",
+    "    def getEarlySaving(predictionStep, numTimeSteps, returnTotal=False):\n",
+    "        predictionStep = predictionStep + 1\n",
+    "        predictionStep = np.reshape(predictionStep, -1)\n",
+    "        totalSteps = np.sum(predictionStep)\n",
+    "        maxSteps = len(predictionStep) * numTimeSteps\n",
+    "        savings = 1.0 - (totalSteps / maxSteps)\n",
+    "        if returnTotal:\n",
+    "            return savings, totalSteps\n",
+    "        return savings\n",
+    "    \n",
+    "    #--------------------------------END OF EARLY SAVINGS---------------------------------#\n",
+    "    \n",
+    "    #----------------------------------------BEST MODEL-----------------------------------#\n",
+    "    \n",
+    "    k = 2\n",
+    "    predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
+    "                                                                   minProb=0.99, keep_prob=1.0)\n",
+    "    bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
+    "    print('Accuracy at k = %d: %f' % (k,  np.mean((bagPredictions == BAG_TEST).astype(int))))\n",
+    "    mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n",
+    "    emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
+    "    total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
+    "    print('Savings due to MI-RNN : %f' % mi_savings)\n",
+    "    print('Savings due to Early prediction: %f' % emi_savings)\n",
+    "    print('Total Savings: %f' % (total_savings))\n",
+    "    \n",
+    "    #Store in the dictionary.\n",
+    "    lstm_dict[\"k\"] = k\n",
+    "    lstm_dict[\"accuracy\"] = np.mean((bagPredictions == BAG_TEST).astype(int))\n",
+    "    lstm_dict[\"total_savings\"] = total_savings\n",
+    "    lstm_dict[\"y_test\"] = BAG_TEST\n",
+    "    lstm_dict[\"y_pred\"] = bagPredictions\n",
+    "    \n",
+    "    # A slightly more detailed analysis method is provided. \n",
+    "    df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)\n",
+    "    print (tabulate(df, headers=list(df.columns), tablefmt='grid'))\n",
+    "    \n",
+    "    lstm_dict[\"detailed analysis\"] = df\n",
+    "    #----------------------------------END OF BEST MODEL-----------------------------------#\n",
+    "    \n",
+    "    #----------------------------------PICKING THE BEST MODEL------------------------------#\n",
+    "    \n",
+    "    devnull = open(os.devnull, 'r')\n",
+    "    for val in modelStats:\n",
+    "        round_, acc, modelPrefix, globalStep = val\n",
+    "        emiDriver.loadSavedGraphToNewSession(modelPrefix, globalStep, redirFile=devnull)\n",
+    "        predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
+    "                                                                   minProb=0.99, keep_prob=1.0)\n",
+    "\n",
+    "        bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
+    "        print(\"Round: %2d, Validation accuracy: %.4f\" % (round_, acc), end='')\n",
+    "        print(', Test Accuracy (k = %d): %f, ' % (k,  np.mean((bagPredictions == BAG_TEST).astype(int))), end='')\n",
+    "        print('Additional savings: %f' % getEarlySaving(predictionStep, NUM_TIMESTEPS)) \n",
+    "        \n",
+    "    \n",
+    "    #-------------------------------END OF PICKING THE BEST MODEL--------------------------#\n",
+    "\n",
+    "    return lstm_dict"
+   ]
+  },
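+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The savings reported at the end of `lstm_experiment_generator` combine two terms: MI-RNN savings from processing only `NUM_TIMESTEPS` of the `ORIGINAL_NUM_TIMESTEPS` steps, and EMI savings from stopping early once the predicted class probability crosses `minProb`. The sketch below works through that arithmetic, assuming `NUM_TIMESTEPS = 8` and `ORIGINAL_NUM_TIMESTEPS = 20` (the 8_3 bagging used in this notebook); the EMI savings value is a placeholder for whatever `getEarlySaving` returns."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch of the savings arithmetic used in lstm_experiment_generator.\n",
+    "NUM_TIMESTEPS, ORIGINAL_NUM_TIMESTEPS = 8, 20   # 8_3 bagging of 20-step bags\n",
+    "emi_savings = 0.2087                            # placeholder for getEarlySaving(...)\n",
+    "mi_savings = 1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS   # 0.6\n",
+    "total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
+    "print(mi_savings, total_savings)                # 0.6, ~0.683"
+   ]
+  },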
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:43:51.798834Z",
+     "start_time": "2019-06-19T05:43:51.792938Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def experiment_generator(params, path, model = 'lstm'):\n",
+    "    \n",
+    "    \n",
+    "    if (model == 'lstm'): return lstm_experiment_generator(params, path)\n",
+    "    elif (model == 'fastgrnn'): return fastgrnn_experiment_generator(params, path)\n",
+    "    elif (model == 'gru'): return gru_experiment_generator(params, path)\n",
+    "    elif (model == 'baseline'): return baseline_experiment_generator(params, path)\n",
+    "    \n",
+    "    return "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'/home/sf/data/EdgeML/tf/examples/EMI-RNN'"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pwd"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T05:43:53.168413Z",
+     "start_time": "2019-06-19T05:43:53.164622Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import tensorflow as tf"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-06-19T08:16:13.770690Z",
+     "start_time": "2019-06-19T05:45:31.777404Z"
+    },
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "x_train shape is: (3679, 5, 8, 22)\n",
+      "y_train shape is: (3679, 5, 2)\n",
+      "x_test shape is: (409, 5, 8, 22)\n",
+      "y_test shape is: (409, 5, 2)\n",
+      "Update policy: top-k\n",
+      "Training with MI-RNN loss for 15 rounds\n",
+      "Round: 0\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.08400 Acc 0.59375 | Val acc 0.58924 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1000\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.08510 Acc 0.57500 | Val acc 0.63081 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1001\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.07742 Acc 0.61250 | Val acc 0.66015 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1002\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.07279 Acc 0.67500 | Val acc 0.66504 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1003\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1003\n",
+      "Round: 1\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.06990 Acc 0.71875 | Val acc 0.67971 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1004\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.06914 Acc 0.72500 | Val acc 0.68215 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1005\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.06405 Acc 0.76250 | Val acc 0.69682 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1006\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.06493 Acc 0.73125 | Val acc 0.70660 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1007\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1007\n",
+      "Round: 2\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.06088 Acc 0.76875 | Val acc 0.72127 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1008\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.05882 Acc 0.75000 | Val acc 0.73594 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1009\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.05941 Acc 0.79375 | Val acc 0.74083 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1010\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.06156 Acc 0.76875 | Val acc 0.74817 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1011\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1011\n",
+      "Round: 3\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.06403 Acc 0.76250 | Val acc 0.75550 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1012\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.05794 Acc 0.75625 | Val acc 0.75795 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1013\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.05490 Acc 0.82500 | Val acc 0.76773 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1014\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.05780 Acc 0.81250 | Val acc 0.78240 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1015\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1015\n",
+      "Round: 4\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.05379 Acc 0.83750 | Val acc 0.77995 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1016\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.05840 Acc 0.80000 | Val acc 0.78240 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1017\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.05547 Acc 0.81250 | Val acc 0.77506 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1018\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.05212 Acc 0.81875 | Val acc 0.78973 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1019\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1019\n",
+      "Round: 5\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04983 Acc 0.84375 | Val acc 0.78973 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1020\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04638 Acc 0.85000 | Val acc 0.80196 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1021\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04734 Acc 0.83750 | Val acc 0.80440 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1022\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.05063 Acc 0.85625 | Val acc 0.80440 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1023\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1022\n",
+      "Round: 6\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04497 Acc 0.86875 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1024\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04768 Acc 0.83125 | Val acc 0.79951 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1025\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04779 Acc 0.81875 | Val acc 0.80929 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1026\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04416 Acc 0.81250 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1027\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1027\n",
+      "Round: 7\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04235 Acc 0.84375 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1028\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04155 Acc 0.84375 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1029\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03803 Acc 0.85000 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1030\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04139 Acc 0.85000 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1031\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1031\n",
+      "Round: 8\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04385 Acc 0.81875 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1032\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03797 Acc 0.87500 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1033\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03902 Acc 0.83750 | Val acc 0.80929 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1034\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03691 Acc 0.85000 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1035\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1032\n",
+      "Round: 9\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03872 Acc 0.87500 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1036\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04312 Acc 0.82500 | Val acc 0.81907 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1037\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03735 Acc 0.87500 | Val acc 0.82152 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1038\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04043 Acc 0.84375 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1039\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1039\n",
+      "Round: 10\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03696 Acc 0.85000 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1040\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03923 Acc 0.86250 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1041\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03495 Acc 0.87500 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1042\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03587 Acc 0.88750 | Val acc 0.81907 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1043\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1042\n",
+      "Round: 11\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04023 Acc 0.86250 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1044\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03841 Acc 0.85625 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1045\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03969 Acc 0.85000 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1046\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch   1 Batch   110 (  225) Loss 0.03921 Acc 0.85000 | Val acc 0.81418 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1047\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1044\n",
+      "Round: 12\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04475 Acc 0.81875 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1048\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03966 Acc 0.85625 | Val acc 0.83130 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1049\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03139 Acc 0.90000 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1050\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.04275 Acc 0.81250 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1051\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1050\n",
+      "Round: 13\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03359 Acc 0.87500 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1052\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03795 Acc 0.86250 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1053\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03523 Acc 0.87500 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1054\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03383 Acc 0.89375 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1055\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1054\n",
+      "Round: 14\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.02968 Acc 0.89375 | Val acc 0.81418 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1056\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.03354 Acc 0.85625 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1057\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.02782 Acc 0.90000 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1058\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.02719 Acc 0.89375 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1059\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1058\n",
+      "Round: 15\n",
+      "Switching to EMI-Loss function\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.44879 Acc 0.84375 | Val acc 0.80929 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1060\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.42052 Acc 0.85000 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1061\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.37942 Acc 0.87500 | Val acc 0.82152 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1062\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.38046 Acc 0.88125 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1063\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1063\n",
+      "Round: 16\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.38822 Acc 0.86875 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1064\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.37139 Acc 0.88750 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1065\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.38185 Acc 0.88125 | Val acc 0.82152 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1066\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.37084 Acc 0.86875 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1067\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1065\n",
+      "Round: 17\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.37820 Acc 0.89375 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1068\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.36451 Acc 0.84375 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1069\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.36738 Acc 0.89375 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1070\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.40073 Acc 0.88750 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1071\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1068\n",
+      "Round: 18\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.38454 Acc 0.87500 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1072\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.36183 Acc 0.90000 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1073\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.38793 Acc 0.90000 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1074\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.38047 Acc 0.83750 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1075\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1072\n",
+      "Round: 19\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.38943 Acc 0.85625 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1076\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.40676 Acc 0.86250 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1077\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.37826 Acc 0.90000 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1078\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.39518 Acc 0.86250 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1079\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1078\n",
+      "Round: 20\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.38738 Acc 0.85625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1080\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.38468 Acc 0.85625 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1081\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.35865 Acc 0.89375 | Val acc 0.83130 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1082\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.33802 Acc 0.92500 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1083\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1083\n",
+      "Round: 21\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.35882 Acc 0.89375 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1084\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.37442 Acc 0.85625 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1085\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.40794 Acc 0.87500 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1086\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.33659 Acc 0.91250 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1087\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1084\n",
+      "Round: 22\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.33654 Acc 0.88125 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1088\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.36120 Acc 0.88125 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1089\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.34685 Acc 0.88125 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1090\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.32916 Acc 0.93125 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1091\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1091\n",
+      "Round: 23\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.34161 Acc 0.90625 | Val acc 0.83130 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1092\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.39510 Acc 0.85000 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1093\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch   1 Batch   110 (  225) Loss 0.37442 Acc 0.91875 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1094\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.32392 Acc 0.91875 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1095\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1094\n",
+      "Round: 24\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.32550 Acc 0.90625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1096\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.34901 Acc 0.90625 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1097\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.34062 Acc 0.90625 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1098\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.33095 Acc 0.90625 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1099\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1099\n",
+      "Round: 25\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.34793 Acc 0.90000 | Val acc 0.85819 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1100\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.32091 Acc 0.92500 | Val acc 0.86553 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1101\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.35738 Acc 0.90625 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1102\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.34036 Acc 0.90000 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1103\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1101\n",
+      "Round: 26\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.34306 Acc 0.88750 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1104\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.32850 Acc 0.92500 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1105\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.34266 Acc 0.88750 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1106\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.34165 Acc 0.89375 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1107\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1105\n",
+      "Round: 27\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.32461 Acc 0.89375 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1108\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.33608 Acc 0.89375 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1109\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.33137 Acc 0.90625 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1110\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.28914 Acc 0.95000 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1111\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1110\n",
+      "Round: 28\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.35324 Acc 0.92500 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1112\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.31689 Acc 0.91875 | Val acc 0.86797 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1113\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.32767 Acc 0.92500 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1114\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.30775 Acc 0.91875 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1115\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1113\n",
+      "Round: 29\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.33176 Acc 0.91250 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1116\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.31769 Acc 0.90625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1117\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.32659 Acc 0.90625 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1118\n",
+      "Epoch   1 Batch   110 (  225) Loss 0.31592 Acc 0.93125 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1119\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1118\n",
+      "Accuracy at k = 2: 0.863992\n",
+      "Savings due to MI-RNN : 0.600000\n",
+      "Savings due to Early prediction: 0.208659\n",
+      "Total Savings: 0.683464\n",
+      "   len       acc  macro-fsc  macro-pre  macro-rec  micro-fsc  micro-pre  \\\n",
+      "0    1  0.842466   0.842041   0.849946   0.844332   0.842466   0.842466   \n",
+      "1    2  0.863992   0.863982   0.865217   0.864776   0.863992   0.863992   \n",
+      "2    3  0.867906   0.867860   0.867807   0.867995   0.867906   0.867906   \n",
+      "3    4  0.864971   0.864697   0.865674   0.864385   0.864971   0.864971   \n",
+      "4    5  0.857143   0.856430   0.860480   0.855905   0.857143   0.857143   \n",
+      "\n",
+      "   micro-rec  fscore_01  \n",
+      "0   0.842466   0.850233  \n",
+      "1   0.863992   0.865179  \n",
+      "2   0.867906   0.865404  \n",
+      "3   0.864971   0.858607  \n",
+      "4   0.857143   0.846316  \n",
+      "Max accuracy 0.867906 at subsequencelength 3\n",
+      "Max micro-f 0.867906 at subsequencelength 3\n",
+      "Micro-precision 0.867906 at subsequencelength 3\n",
+      "Micro-recall 0.867906 at subsequencelength 3\n",
+      "Max macro-f 0.867860 at subsequencelength 3\n",
+      "macro-precision 0.867807 at subsequencelength 3\n",
+      "macro-recall 0.867995 at subsequencelength 3\n",
+      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
+      "|    |   len |      acc |   macro-fsc |   macro-pre |   macro-rec |   micro-fsc |   micro-pre |   micro-rec |   fscore_01 |\n",
+      "+====+=======+==========+=============+=============+=============+=============+=============+=============+=============+\n",
+      "|  0 |     1 | 0.842466 |    0.842041 |    0.849946 |    0.844332 |    0.842466 |    0.842466 |    0.842466 |    0.850233 |\n",
+      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
+      "|  1 |     2 | 0.863992 |    0.863982 |    0.865217 |    0.864776 |    0.863992 |    0.863992 |    0.863992 |    0.865179 |\n",
+      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
+      "|  2 |     3 | 0.867906 |    0.86786  |    0.867807 |    0.867995 |    0.867906 |    0.867906 |    0.867906 |    0.865404 |\n",
+      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
+      "|  3 |     4 | 0.864971 |    0.864697 |    0.865674 |    0.864385 |    0.864971 |    0.864971 |    0.864971 |    0.858607 |\n",
+      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
+      "|  4 |     5 | 0.857143 |    0.85643  |    0.86048  |    0.855905 |    0.857143 |    0.857143 |    0.857143 |    0.846316 |\n",
+      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1003\n",
+      "Round:  0, Validation accuracy: 0.6650, Test Accuracy (k = 2): 0.684932, Additional savings: 0.006531\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1007\n",
+      "Round:  1, Validation accuracy: 0.7066, Test Accuracy (k = 2): 0.687867, Additional savings: 0.008072\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1011\n",
+      "Round:  2, Validation accuracy: 0.7482, Test Accuracy (k = 2): 0.721135, Additional savings: 0.014555\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1015\n",
+      "Round:  3, Validation accuracy: 0.7824, Test Accuracy (k = 2): 0.745597, Additional savings: 0.020597\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1019\n",
+      "Round:  4, Validation accuracy: 0.7897, Test Accuracy (k = 2): 0.776908, Additional savings: 0.027740\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1022\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Round:  5, Validation accuracy: 0.8044, Test Accuracy (k = 2): 0.793542, Additional savings: 0.037769\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1027\n",
+      "Round:  6, Validation accuracy: 0.8337, Test Accuracy (k = 2): 0.817025, Additional savings: 0.049315\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1031\n",
+      "Round:  7, Validation accuracy: 0.8289, Test Accuracy (k = 2): 0.826810, Additional savings: 0.054795\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1032\n",
+      "Round:  8, Validation accuracy: 0.8362, Test Accuracy (k = 2): 0.825832, Additional savings: 0.060372\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1039\n",
+      "Round:  9, Validation accuracy: 0.8264, Test Accuracy (k = 2): 0.821918, Additional savings: 0.068249\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1042\n",
+      "Round: 10, Validation accuracy: 0.8289, Test Accuracy (k = 2): 0.820939, Additional savings: 0.074609\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1044\n",
+      "Round: 11, Validation accuracy: 0.8362, Test Accuracy (k = 2): 0.830724, Additional savings: 0.078914\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1050\n",
+      "Round: 12, Validation accuracy: 0.8362, Test Accuracy (k = 2): 0.835616, Additional savings: 0.090411\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1054\n",
+      "Round: 13, Validation accuracy: 0.8411, Test Accuracy (k = 2): 0.835616, Additional savings: 0.098973\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1058\n",
+      "Round: 14, Validation accuracy: 0.8460, Test Accuracy (k = 2): 0.829746, Additional savings: 0.109638\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1063\n",
+      "Round: 15, Validation accuracy: 0.8337, Test Accuracy (k = 2): 0.827789, Additional savings: 0.111913\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1065\n",
+      "Round: 16, Validation accuracy: 0.8289, Test Accuracy (k = 2): 0.820939, Additional savings: 0.128156\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1068\n",
+      "Round: 17, Validation accuracy: 0.8411, Test Accuracy (k = 2): 0.830724, Additional savings: 0.129525\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1072\n",
+      "Round: 18, Validation accuracy: 0.8337, Test Accuracy (k = 2): 0.829746, Additional savings: 0.134760\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1078\n",
+      "Round: 19, Validation accuracy: 0.8386, Test Accuracy (k = 2): 0.837573, Additional savings: 0.143102\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1083\n",
+      "Round: 20, Validation accuracy: 0.8435, Test Accuracy (k = 2): 0.835616, Additional savings: 0.161717\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1084\n",
+      "Round: 21, Validation accuracy: 0.8411, Test Accuracy (k = 2): 0.850294, Additional savings: 0.167392\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1091\n",
+      "Round: 22, Validation accuracy: 0.8509, Test Accuracy (k = 2): 0.846380, Additional savings: 0.177886\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1094\n",
+      "Round: 23, Validation accuracy: 0.8484, Test Accuracy (k = 2): 0.860078, Additional savings: 0.175612\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1099\n",
+      "Round: 24, Validation accuracy: 0.8460, Test Accuracy (k = 2): 0.851272, Additional savings: 0.189750\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1101\n",
+      "Round: 25, Validation accuracy: 0.8655, Test Accuracy (k = 2): 0.854207, Additional savings: 0.193395\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1105\n",
+      "Round: 26, Validation accuracy: 0.8557, Test Accuracy (k = 2): 0.853229, Additional savings: 0.197114\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1110\n",
+      "Round: 27, Validation accuracy: 0.8557, Test Accuracy (k = 2): 0.851272, Additional savings: 0.199731\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1113\n",
+      "Round: 28, Validation accuracy: 0.8680, Test Accuracy (k = 2): 0.860078, Additional savings: 0.208586\n",
+      "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1118\n",
+      "Round: 29, Validation accuracy: 0.8557, Test Accuracy (k = 2): 0.863992, Additional savings: 0.208659\n",
+      "Results for this run have been saved at lstm .\n"
+     ]
+    }
+   ],
+   "source": [
+    "## Baseline EMI-LSTM\n",
+    "\n",
+    "dataset = 'SWELL-KW'\n",
+    "path = '/home/sf/data/SWELL-KW/8_3/'\n",
+    "\n",
+    "#Choose model from among [lstm, fastgrnn, gru]\n",
+    "model = 'lstm'\n",
+    "\n",
+    "# Dictionary to set the parameters.\n",
+    "params = {\n",
+    "    \"NUM_HIDDEN\" : 128,\n",
+    "    \"NUM_TIMESTEPS\" : 8, #subinstance length.\n",
+    "    \"ORIGINAL_NUM_TIMESTEPS\" : 20,\n",
+    "    \"NUM_FEATS\" : 22,\n",
+    "    \"FORGET_BIAS\" : 1.0,\n",
+    "    \"NUM_OUTPUT\" : 2,\n",
+    "    \"USE_DROPOUT\" : 1, # '1' -> True. '0' -> False\n",
+    "    \"KEEP_PROB\" : 0.75,\n",
+    "    \"PREFETCH_NUM\" : 5,\n",
+    "    \"BATCH_SIZE\" : 32,\n",
+    "    \"NUM_EPOCHS\" : 2,\n",
+    "    \"NUM_ITER\" : 4,\n",
+    "    \"NUM_ROUNDS\" : 30,\n",
+    "    \"LEARNING_RATE\" : 0.001,\n",
+    "    \"FRAC_EMI\" : 0.5,\n",
+    "    \"MODEL_PREFIX\" : '/home/sf/data/SWELL-KW/models/model-lstm'\n",
+    "}\n",
+    "\n",
+    "#Preprocess data, and load the train,test and validation splits.\n",
+    "lstm_dict = lstm_experiment_generator(params, path)\n",
+    "\n",
+    "#Create the directory to store the results of this run.\n",
+    "\n",
+    "dirname = \"\" + model\n",
+    "pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)\n",
+    "print (\"Results for this run have been saved at\" , dirname, \".\")\n",
+    "\n",
+    "now = datetime.datetime.now()\n",
+    "filename = list((str(now.year),\"-\",str(now.month),\"-\",str(now.day),\"|\",str(now.hour),\"-\",str(now.minute)))\n",
+    "filename = ''.join(filename)\n",
+    "\n",
+    "#Save the dictionary containing the params and the results.\n",
+    "pkl.dump(lstm_dict,open(dirname + \"/lstm_dict_\" + filename + \".pkl\",mode='wb'))"
+   ]
+  },
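+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The run above reports `Savings due to MI-RNN : 0.600000`, `Savings due to Early prediction: 0.208659`, and `Total Savings: 0.683464`. These figures are consistent with the early-prediction savings applying only to the fraction of computation that MI-RNN has not already skipped. The cell below is a minimal sanity check of that (assumed) relationship, plus a hypothetical snippet for reloading the pickled results; the keys inside `lstm_dict` depend on what `lstm_experiment_generator` returns."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Savings reported in the training log above.\n",
+    "mi_savings = 0.600000     # Savings due to MI-RNN\n",
+    "early_savings = 0.208659  # Savings due to Early prediction\n",
+    "\n",
+    "# Assumed relationship: early prediction only saves work on the steps that\n",
+    "# MI-RNN still evaluates.\n",
+    "total_savings = mi_savings + (1.0 - mi_savings) * early_savings\n",
+    "print(round(total_savings, 6))  # ~0.683464, matching 'Total Savings' above\n",
+    "\n",
+    "# Hypothetical reload of the saved results (requires dirname/filename from\n",
+    "# the previous cell):\n",
+    "# with open(dirname + \"/lstm_dict_\" + filename + \".pkl\", mode='rb') as f:\n",
+    "#     lstm_dict = pkl.load(f)"
+   ]
+  }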
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}