[d395cf]: / DREAMER / DREAMER_Arousal_LSTM_64_16.ipynb

Download this file

1381 lines (1380 with data), 54.2 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DREAMER Arousal EMI-LSTM 64_16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adapted from Microsoft's notebooks, available at https://github.com/microsoft/EdgeML authored by Dennis et al."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-15T18:30:17.522073Z",
     "start_time": "2019-07-15T18:30:17.217355Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tabulate import tabulate\n",
    "import os\n",
    "import datetime as datetime\n",
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DataFrames from CSVs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-15T18:30:28.212447Z",
     "start_time": "2019-07-15T18:30:17.889941Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/sf/.local/lib/python3.6/site-packages/numpy/lib/arraysetops.py:472: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  mask |= (ar1 == a)\n"
     ]
    }
   ],
   "source": [
    "# combined DREAMER export; the first CSV column is used as the row index\n",
    "DREAMER_CSV = '/home/sf/data/DREAMER/DREAMER_combined.csv'\n",
    "df = pd.read_csv(DREAMER_CSV, index_col=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-18T16:01:09.440118Z",
     "start_time": "2019-06-18T16:01:09.394271Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',\n",
       "       '13', 'Person', 'Movie', 'Arousal', 'Dominance', 'Valence', 'ECG1',\n",
       "       'ECG2'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# numeric feature columns '0'-'13', the Person/Movie/rating columns, and ECG1/ECG2\n",
    "df.columns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Split Ground Truth "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-15T18:30:37.760226Z",
     "start_time": "2019-07-15T18:30:37.597908Z"
    }
   },
   "outputs": [],
   "source": [
    "# features: everything except the identifiers and the three rating columns\n",
    "filtered_train = df.drop(['Movie', 'Person', 'Arousal','Dominance', 'Valence'], axis=1)\n",
    "# shift the 1-5 arousal ratings to 0-based class ids 0-4\n",
    "filtered_target = df['Arousal'].replace({rating: rating - 1 for rating in range(1, 6)})\n",
    "# every label must be a valid class id for the 5-way one-hot encoding used later\n",
    "assert filtered_target.isin(list(range(5))).all(), 'unexpected Arousal rating outside 1-5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-15T18:32:27.581110Z",
     "start_time": "2019-07-15T18:32:27.576162Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10975232,)\n",
      "(10975232, 16)\n"
     ]
    }
   ],
   "source": [
    "for frame in (filtered_target, filtered_train):\n",
    "    print(frame.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:24:41.449331Z",
     "start_time": "2019-06-19T05:24:41.445741Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "BAG_SIZE = 128  # timesteps per bag\n",
    "assert len(filtered_target) % BAG_SIZE == 0, 'sample count must be a multiple of the bag size'\n",
    "# one row of per-sample labels per bag; the number of bags is inferred instead of hard-coded\n",
    "y = filtered_target.values.reshape(-1, BAG_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert to 3D - (Bags, Timesteps, Features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of feature columns per timestep\n",
    "len(filtered_train.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:24:43.963421Z",
     "start_time": "2019-06-19T05:24:43.944207Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10975232, 16)\n",
      "(85744, 128, 16)\n"
     ]
    }
   ],
   "source": [
    "x = filtered_train.values\n",
    "print(x.shape)\n",
    "# (samples, feats) -> (bags, 128 timesteps, feats); integer division and the\n",
    "# actual column count replace float division and a hard-coded feature count\n",
    "x = x.reshape(len(x) // 128, 128, x.shape[1])\n",
    "print(x.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Filter Bags with Mixed Labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:25:19.299273Z",
     "start_time": "2019-06-19T05:25:18.952971Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n"
     ]
    }
   ],
   "source": [
    "# a bag is dropped when its per-sample labels are not all identical\n",
    "mixed_rows = (y != y[:, :1]).any(axis=1)\n",
    "bags_to_remove = np.flatnonzero(mixed_rows).tolist()\n",
    "print(bags_to_remove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:25:26.446153Z",
     "start_time": "2019-06-19T05:25:26.256245Z"
    }
   },
   "outputs": [],
   "source": [
    "# drop the mixed-label bags from both the inputs and the labels\n",
    "x, y = (np.delete(arr, bags_to_remove, axis=0) for arr in (x, y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:25:27.260726Z",
     "start_time": "2019-06-19T05:25:27.254474Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(85744, 128, 16)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# expected (bags, 128, features) after filtering\n",
    "x.shape "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:25:28.101475Z",
     "start_time": "2019-06-19T05:25:28.096009Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(85744, 128)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# one row of 128 per-sample labels per bag\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Categorical Representation "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:25:50.094089Z",
     "start_time": "2019-06-19T05:25:49.746284Z"
    }
   },
   "outputs": [],
   "source": [
    "# every remaining bag must carry a single label, so its first sample's label is the\n",
    "# bag label (set().pop() returned an arbitrary element when this did not hold).\n",
    "# NOTE: despite the name, these are integer class ids, not one-hot vectors.\n",
    "assert (y == y[:, :1]).all(), 'bags with mixed labels must be removed first'\n",
    "one_hot_list = list(y[:, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:25:52.203633Z",
     "start_time": "2019-06-19T05:25:52.198467Z"
    }
   },
   "outputs": [],
   "source": [
    "# integer bag labels as a numpy array, one entry per bag\n",
    "categorical_y_ver = np.array(one_hot_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:25:53.132006Z",
     "start_time": "2019-06-19T05:25:53.126314Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(85744,)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# one label per bag\n",
    "categorical_y_ver.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:25:54.495163Z",
     "start_time": "2019-06-19T05:25:54.489349Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "128"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# timesteps per bag\n",
    "x.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:26:08.021392Z",
     "start_time": "2019-06-19T05:26:08.017038Z"
    }
   },
   "outputs": [],
   "source": [
    "def one_hot(y, numOutput):\n",
    "    \"\"\"Encode integer class ids as one-hot rows.\n",
    "\n",
    "    y is flattened first, so any shape is accepted; returns a float array of\n",
    "    shape (number of labels, numOutput) with a single 1 per row.\n",
    "    \"\"\"\n",
    "    labels = np.asarray(y).ravel()\n",
    "    encoded = np.zeros((labels.size, numOutput))\n",
    "    for row, label in enumerate(labels):\n",
    "        encoded[row, label] = 1\n",
    "    return encoded"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract 3D Normalized Data with Validation Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:26:08.931435Z",
     "start_time": "2019-06-19T05:26:08.927397Z"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "import pathlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:26:10.295822Z",
     "start_time": "2019-06-19T05:26:09.832723Z"
    }
   },
   "outputs": [],
   "source": [
    "# 80/20 split of bags into train+val and test, reproducible via random_state\n",
    "x_train_val_combined, x_test, y_train_val_combined, y_test = train_test_split(\n",
    "    x, categorical_y_ver,\n",
    "    test_size=0.20,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:26:11.260084Z",
     "start_time": "2019-06-19T05:26:11.253714Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 3, 2, ..., 3, 3, 3])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# integer test labels (0-4) before one-hot encoding\n",
    "y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-18T17:24:41.832337Z",
     "start_time": "2019-06-18T17:24:41.141678Z"
    }
   },
   "outputs": [],
   "source": [
    "extractedDir = '/home/sf/data/DREAMER/Arousal'\n",
    "timesteps = x_train_val_combined.shape[-2]\n",
    "feats = x_train_val_combined.shape[-1]\n",
    "\n",
    "# train_test_split already shuffled the bags, so the last 10% of the pool serves as validation\n",
    "trainSize = int(x_train_val_combined.shape[0]*0.9)\n",
    "x_train, x_val = x_train_val_combined[:trainSize], x_train_val_combined[trainSize:]\n",
    "y_train, y_val = y_train_val_combined[:trainSize], y_train_val_combined[trainSize:]\n",
    "\n",
    "# per-feature statistics come from the training split only (no val/test leakage)\n",
    "mean = np.mean(np.reshape(x_train, [-1, feats]), axis=0)\n",
    "std = np.std(np.reshape(x_train, [-1, feats]), axis=0)\n",
    "# a constant feature has std 0; dividing by it would fill that feature with NaN/inf\n",
    "std = np.where(std == 0, 1.0, std)\n",
    "\n",
    "def standardize(data):\n",
    "    \"\"\"Standardize a (bags, timesteps, feats) array with the training mean/std.\"\"\"\n",
    "    flat = np.reshape(data, [-1, feats])\n",
    "    return np.reshape((flat - mean) / std, [-1, timesteps, feats])\n",
    "\n",
    "x_train = standardize(x_train)\n",
    "x_val = standardize(x_val)\n",
    "x_test = standardize(x_test)\n",
    "\n",
    "# shuffle test with a fixed seed so the saved test split is reproducible across runs\n",
    "idx = np.random.RandomState(42).permutation(len(x_test))\n",
    "x_test = x_test[idx]\n",
    "y_test = y_test[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-18T17:25:50.674962Z",
     "start_time": "2019-06-18T17:25:50.481068Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/sf/data/DREAMER/Arousal/\n"
     ]
    }
   ],
   "source": [
    "# one-hot encoding of labels (5 arousal classes)\n",
    "numOutput = 5\n",
    "y_train, y_val, y_test = (one_hot(labels, numOutput) for labels in (y_train, y_val, y_test))\n",
    "extractedDir += '/'\n",
    "\n",
    "# persist the normalized splits under <extractedDir>/RAW\n",
    "pathlib.Path(extractedDir + 'RAW').mkdir(parents=True, exist_ok = True)\n",
    "\n",
    "for fileStem, splitArray in ((\"RAW/x_train\", x_train), (\"RAW/y_train\", y_train),\n",
    "                             (\"RAW/x_test\", x_test), (\"RAW/y_test\", y_test),\n",
    "                             (\"RAW/x_val\", x_val), (\"RAW/y_val\", y_val)):\n",
    "    np.save(extractedDir + fileStem, splitArray)\n",
    "\n",
    "print(extractedDir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-18T17:26:35.381650Z",
     "start_time": "2019-06-18T17:26:35.136645Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_test.npy  x_train.npy  x_val.npy  y_test.npy  y_train.npy  y_val.npy\r\n"
     ]
    }
   ],
   "source": [
    "# confirm the six .npy split files were written\n",
    "ls /home/sf/data/DREAMER/Arousal/RAW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-18T17:27:34.458130Z",
     "start_time": "2019-06-18T17:27:34.323712Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(61735, 128, 16)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sanity check of the saved training inputs: (train bags, timesteps, features)\n",
    "np.load('/home/sf/data/DREAMER/Arousal/RAW/x_train.npy').shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make 4D EMI Data (Bags, Subinstances, Subinstance Length, Features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:30:53.720179Z",
     "start_time": "2019-06-19T05:30:53.713756Z"
    }
   },
   "outputs": [],
   "source": [
    "def loadData(dirname):\n",
    "    \"\"\"Load the six split arrays stored as <dirname>/<split>.npy.\n",
    "\n",
    "    Returns (x_train, y_train, x_test, y_test, x_val, y_val) in that order.\n",
    "    \"\"\"\n",
    "    splitNames = ('x_train', 'y_train', 'x_test', 'y_test', 'x_val', 'y_val')\n",
    "    return tuple(np.load(dirname + '/' + splitName + '.npy') for splitName in splitNames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-18T17:29:23.614052Z",
     "start_time": "2019-06-18T17:29:23.601324Z"
    }
   },
   "outputs": [],
   "source": [
    "def bagData(X, Y, subinstanceLen, subinstanceStride, numClass=5, numSteps=128, numFeats=16):\n",
    "    \"\"\"Split every bag into (possibly overlapping) sub-instances for EMI-RNN.\n",
    "\n",
    "    X: array of shape (numBags, numSteps, numFeats).\n",
    "    Y: one-hot labels of shape (numBags, numClass).\n",
    "    subinstanceLen: timesteps per sub-instance, 1..numSteps.\n",
    "    subinstanceStride: offset between consecutive sub-instance starts, 1..numSteps.\n",
    "    numClass, numSteps, numFeats: expected dimensions (defaults match this notebook).\n",
    "\n",
    "    Returns (x_bagged, y_bagged): x_bagged has shape\n",
    "    (numBags, numSubinstance, subinstanceLen, numFeats) and y_bagged has shape\n",
    "    (numBags, numSubinstance, numClass), each sub-instance carrying its bag's label.\n",
    "    A final window shorter than subinstanceLen is zero-padded at the end.\n",
    "    \"\"\"\n",
    "    assert X.ndim == 3\n",
    "    assert X.shape[1] == numSteps\n",
    "    assert X.shape[2] == numFeats\n",
    "    assert 0 < subinstanceLen <= numSteps\n",
    "    # a stride of 0 never advances the window, so the loop below would never terminate\n",
    "    assert 0 < subinstanceStride <= numSteps\n",
    "    assert len(X) == len(Y)\n",
    "    assert Y.ndim == 2\n",
    "    assert Y.shape[1] == numClass\n",
    "    x_bagged = []\n",
    "    y_bagged = []\n",
    "    for i, point in enumerate(X):\n",
    "        instanceList = []\n",
    "        start = 0\n",
    "        end = subinstanceLen\n",
    "        while True:\n",
    "            x = point[start:end, :]\n",
    "            if len(x) < subinstanceLen:\n",
    "                # zero-pad a truncated final window up to subinstanceLen\n",
    "                x_ = np.zeros([subinstanceLen, x.shape[1]])\n",
    "                x_[:len(x), :] = x\n",
    "                x = x_\n",
    "            instanceList.append(x)\n",
    "            if end >= numSteps:\n",
    "                break\n",
    "            start += subinstanceStride\n",
    "            end += subinstanceStride\n",
    "        bag = np.array(instanceList)\n",
    "        # every sub-instance inherits the bag-level class\n",
    "        labelBag = np.zeros([bag.shape[0], numClass])\n",
    "        labelBag[:, np.argmax(Y[i])] = 1\n",
    "        x_bagged.append(bag)\n",
    "        y_bagged.append(labelBag)\n",
    "    return np.array(x_bagged), np.array(y_bagged)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:33:06.531994Z",
     "start_time": "2019-06-19T05:33:06.523884Z"
    }
   },
   "outputs": [],
   "source": [
    "def makeEMIData(subinstanceLen, subinstanceStride, sourceDir, outDir):\n",
    "    \"\"\"Bag the train/test/val splits loaded from sourceDir and save them to outDir.\"\"\"\n",
    "    x_train, y_train, x_test, y_test, x_val, y_val = loadData(sourceDir)\n",
    "    jobs = (\n",
    "        (x_train, y_train, '/x_train.npy', '/y_train.npy', 'Num train %d'),\n",
    "        (x_test, y_test, '/x_test.npy', '/y_test.npy', 'Num test %d'),\n",
    "        (x_val, y_val, '/x_val.npy', '/y_val.npy', 'Num val %d'),\n",
    "    )\n",
    "    for features, labels, xName, yName, message in jobs:\n",
    "        baggedX, baggedY = bagData(features, labels, subinstanceLen, subinstanceStride)\n",
    "        np.save(outDir + xName, baggedX)\n",
    "        np.save(outDir + yName, baggedY)\n",
    "        print(message % len(baggedX))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:35:20.960014Z",
     "start_time": "2019-06-19T05:35:20.050363Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num train 61735\n",
      "Num test 17149\n",
      "Num val 6860\n"
     ]
    }
   ],
   "source": [
    "subinstanceLen = 64\n",
    "subinstanceStride = 16\n",
    "extractedDir = '/home/sf/data/DREAMER/Arousal'\n",
    "rawDir = extractedDir + '/RAW'\n",
    "sourceDir = rawDir\n",
    "outDir = extractedDir + '/%d_%d/' % (subinstanceLen, subinstanceStride)\n",
    "# exist_ok keeps the cell re-runnable (os.mkdir raised FileExistsError on a second run)\n",
    "os.makedirs(outDir, exist_ok=True)\n",
    "makeEMIData(subinstanceLen, subinstanceStride, sourceDir, outDir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-18T17:48:58.293843Z",
     "start_time": "2019-06-18T17:48:58.285383Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(61735, 5, 5)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# bagged labels: (train bags, sub-instances per bag, classes)\n",
    "np.load('/home/sf/data/DREAMER/Arousal/64_16/y_train.npy').shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-18T12:49:10.278023Z",
     "start_time": "2019-06-18T12:49:10.273245Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(61735, 128, 16)\n",
      "(61735, 5)\n"
     ]
    }
   ],
   "source": [
    "for arr in (x_train, y_train):\n",
    "    print(arr.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:35:48.609552Z",
     "start_time": "2019-06-19T05:35:48.604291Z"
    }
   },
   "outputs": [],
   "source": [
    "from edgeml.graph.rnn import EMI_DataPipeline\n",
    "from edgeml.graph.rnn import EMI_BasicLSTM, EMI_FastGRNN, EMI_FastRNN, EMI_GRU\n",
    "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
    "import edgeml.utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:43:48.032413Z",
     "start_time": "2019-06-19T05:43:47.987839Z"
    }
   },
   "outputs": [],
   "source": [
    "def lstm_experiment_generator(params, path = './DSAAR/64_16/'):\n",
    "    \"\"\"\n",
    "        Build, train and evaluate an EMI-LSTM on pre-bagged EMI data.\n",
    "        Inputs : \n",
    "        (1) params : dict of hyperparameters; the keys read below are NUM_HIDDEN, NUM_TIMESTEPS,\n",
    "            ORIGINAL_NUM_TIMESTEPS, NUM_FEATS, FORGET_BIAS, NUM_OUTPUT, USE_DROPOUT, KEEP_PROB,\n",
    "            PREFETCH_NUM, BATCH_SIZE, NUM_EPOCHS, NUM_ITER, NUM_ROUNDS, LEARNING_RATE, MODEL_PREFIX.\n",
    "        (2) path : directory prefix, including the trailing '/', holding x_train.npy, y_train.npy,\n",
    "            x_test.npy, y_test.npy, x_val.npy and y_val.npy (e.g. the output of makeEMIData).\n",
    "        Returns : a copy of params extended with 'k', 'accuracy', 'total_savings', 'y_test',\n",
    "            'y_pred' and 'detailed analysis'.\n",
    "        Side effects : patches EMI_BasicLSTM._createExtendedGraph / _restoreExtendedGraph and,\n",
    "            when USE_DROPOUT == 1, EMI_Driver.feedDictFunc at class level.\n",
    "    \"\"\"\n",
    "    \n",
    "    #Copy the contents of the params dictionary.\n",
    "    lstm_dict = {**params}\n",
    "    \n",
    "    #---------------------------PARAM SETTING----------------------#\n",
    "    \n",
    "    # Network parameters for our LSTM + FC Layer\n",
    "    NUM_HIDDEN = params[\"NUM_HIDDEN\"]\n",
    "    NUM_TIMESTEPS = params[\"NUM_TIMESTEPS\"]\n",
    "    ORIGINAL_NUM_TIMESTEPS = params[\"ORIGINAL_NUM_TIMESTEPS\"]\n",
    "    NUM_FEATS = params[\"NUM_FEATS\"]\n",
    "    FORGET_BIAS = params[\"FORGET_BIAS\"]\n",
    "    NUM_OUTPUT = params[\"NUM_OUTPUT\"]\n",
    "    USE_DROPOUT = True if (params[\"USE_DROPOUT\"] == 1) else False\n",
    "    KEEP_PROB = params[\"KEEP_PROB\"]\n",
    "\n",
    "    # For dataset API (PREFETCH_NUM is read but not used further in this function)\n",
    "    PREFETCH_NUM = params[\"PREFETCH_NUM\"]\n",
    "    BATCH_SIZE = params[\"BATCH_SIZE\"]\n",
    "\n",
    "    # Number of epochs in *one iteration*\n",
    "    NUM_EPOCHS = params[\"NUM_EPOCHS\"]\n",
    "    # Number of iterations in *one round*. After each iteration,\n",
    "    # the model is dumped to disk. At the end of the current\n",
    "    # round, the best model among all the dumped models in the\n",
    "    # current round is picked up..\n",
    "    NUM_ITER = params[\"NUM_ITER\"]\n",
    "    # A round consists of multiple training iterations and a belief\n",
    "    # update step using the best model from all of these iterations\n",
    "    NUM_ROUNDS = params[\"NUM_ROUNDS\"]\n",
    "    LEARNING_RATE = params[\"LEARNING_RATE\"]\n",
    "\n",
    "    # A staging direcory to store models\n",
    "    MODEL_PREFIX = params[\"MODEL_PREFIX\"]\n",
    "    \n",
    "    #----------------------END OF PARAM SETTING----------------------#\n",
    "    \n",
    "    #----------------------DATA LOADING------------------------------#\n",
    "    \n",
    "    x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n",
    "    x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n",
    "    x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n",
    "\n",
    "    # BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
    "    # step of EMI/MI RNN\n",
    "    BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
    "    BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
    "    BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
    "    NUM_SUBINSTANCE = x_train.shape[1]\n",
    "    # each label now matches the array it reports (the test lines used to print the val shapes)\n",
    "    print(\"x_train shape is:\", x_train.shape)\n",
    "    print(\"y_train shape is:\", y_train.shape)\n",
    "    print(\"x_val shape is:\", x_val.shape)\n",
    "    print(\"y_val shape is:\", y_val.shape)\n",
    "    print(\"x_test shape is:\", x_test.shape)\n",
    "    print(\"y_test shape is:\", y_test.shape)\n",
    "    \n",
    "    #----------------------END OF DATA LOADING------------------------------#    \n",
    "    \n",
    "    #----------------------COMPUTATION GRAPH--------------------------------#\n",
    "    \n",
    "    # Define the linear secondary classifier\n",
    "    def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
    "        W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
    "        B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
    "        y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
    "        self.output = y_cap\n",
    "        self.graphCreated = True\n",
    "\n",
    "    def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
    "        y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
    "        self.output = y_cap\n",
    "        self.graphCreated = True\n",
    "\n",
    "    def feedDictFunc(self, keep_prob=None, inference=False, **kwargs):\n",
    "        if inference is False:\n",
    "            feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
    "        else:\n",
    "            feedDict = {self._emiGraph.keep_prob: 1.0}\n",
    "        return feedDict\n",
    "\n",
    "    EMI_BasicLSTM._createExtendedGraph = createExtendedGraph\n",
    "    EMI_BasicLSTM._restoreExtendedGraph = restoreExtendedGraph\n",
    "\n",
    "    if USE_DROPOUT is True:\n",
    "        EMI_Driver.feedDictFunc = feedDictFunc\n",
    "    \n",
    "    inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
    "    emiLSTM = EMI_BasicLSTM(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS,\n",
    "                            forgetBias=FORGET_BIAS, useDropout=USE_DROPOUT)\n",
    "    emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy',\n",
    "                             stepSize=LEARNING_RATE)\n",
    "    \n",
    "    tf.reset_default_graph()\n",
    "    g1 = tf.Graph()    \n",
    "    with g1.as_default():\n",
    "        # Obtain the iterators to each batch of the data\n",
    "        x_batch, y_batch = inputPipeline()\n",
    "        # Create the forward computation graph based on the iterators\n",
    "        y_cap = emiLSTM(x_batch)\n",
    "        # Create loss graphs and training routines\n",
    "        emiTrainer(y_cap, y_batch)\n",
    "        \n",
    "    #------------------------------END OF COMPUTATION GRAPH------------------------------#\n",
    "    \n",
    "    #-------------------------------------EMI DRIVER-------------------------------------#\n",
    "        \n",
    "    with g1.as_default():\n",
    "        emiDriver = EMI_Driver(inputPipeline, emiLSTM, emiTrainer)\n",
    "\n",
    "    emiDriver.initializeSession(g1)\n",
    "    y_updated, modelStats = emiDriver.run(numClasses=NUM_OUTPUT, x_train=x_train,\n",
    "                                          y_train=y_train, bag_train=BAG_TRAIN,\n",
    "                                          x_val=x_val, y_val=y_val, bag_val=BAG_VAL,\n",
    "                                          numIter=NUM_ITER, keep_prob=KEEP_PROB,\n",
    "                                          numRounds=NUM_ROUNDS, batchSize=BATCH_SIZE,\n",
    "                                          numEpochs=NUM_EPOCHS, modelPrefix=MODEL_PREFIX,\n",
    "                                          fracEMI=0.5, updatePolicy='top-k', k=1)\n",
    "    \n",
    "    #-------------------------------END OF EMI DRIVER-------------------------------------#\n",
    "    \n",
    "    #-----------------------------------EARLY SAVINGS-------------------------------------#\n",
    "    \n",
    "    \"\"\"\n",
    "        Early Prediction Policy: We make an early prediction based on the predicted classes\n",
    "        probability. If the predicted class probability > minProb at some step, we make\n",
    "        a prediction at that step.\n",
    "    \"\"\"\n",
    "    def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
    "        assert instanceOut.ndim == 2\n",
    "        classes = np.argmax(instanceOut, axis=1)\n",
    "        prob = np.max(instanceOut, axis=1)\n",
    "        index = np.where(prob >= minProb)[0]\n",
    "        if len(index) == 0:\n",
    "            assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
    "            return classes[-1], len(instanceOut) - 1\n",
    "        index = index[0]\n",
    "        return classes[index], index\n",
    "\n",
    "    def getEarlySaving(predictionStep, numTimeSteps, returnTotal=False):\n",
    "        predictionStep = predictionStep + 1\n",
    "        predictionStep = np.reshape(predictionStep, -1)\n",
    "        totalSteps = np.sum(predictionStep)\n",
    "        maxSteps = len(predictionStep) * numTimeSteps\n",
    "        savings = 1.0 - (totalSteps / maxSteps)\n",
    "        if returnTotal:\n",
    "            return savings, totalSteps\n",
    "        return savings\n",
    "    \n",
    "    #--------------------------------END OF EARLY SAVINGS---------------------------------#\n",
    "    \n",
    "    #----------------------------------------BEST MODEL-----------------------------------#\n",
    "    \n",
    "    k = 2\n",
    "    predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
    "                                                                   minProb=0.99, keep_prob=1.0)\n",
    "    bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
    "    # computed once and reused for both the report and the returned dictionary\n",
    "    accuracy = np.mean((bagPredictions == BAG_TEST).astype(int))\n",
    "    print('Accuracy at k = %d: %f' % (k, accuracy))\n",
    "    mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n",
    "    emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
    "    total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
    "    print('Savings due to MI-RNN : %f' % mi_savings)\n",
    "    print('Savings due to Early prediction: %f' % emi_savings)\n",
    "    print('Total Savings: %f' % (total_savings))\n",
    "    \n",
    "    #Store in the dictionary.\n",
    "    lstm_dict[\"k\"] = k\n",
    "    lstm_dict[\"accuracy\"] = accuracy\n",
    "    lstm_dict[\"total_savings\"] = total_savings\n",
    "    lstm_dict[\"y_test\"] = BAG_TEST\n",
    "    lstm_dict[\"y_pred\"] = bagPredictions\n",
    "    \n",
    "    # A slightly more detailed analysis method is provided. \n",
    "    df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)\n",
    "    print (tabulate(df, headers=list(df.columns), tablefmt='grid'))\n",
    "    \n",
    "    lstm_dict[\"detailed analysis\"] = df\n",
    "    #----------------------------------END OF BEST MODEL-----------------------------------#\n",
    "    \n",
    "    #----------------------------------PICKING THE BEST MODEL------------------------------#\n",
    "    \n",
    "    # redirFile is passed to loadSavedGraphToNewSession, presumably as a sink for restore\n",
    "    # messages; a read-only handle is not a valid output sink, and the handle must be closed\n",
    "    with open(os.devnull, 'w') as devnull:\n",
    "        for val in modelStats:\n",
    "            round_, acc, modelPrefix, globalStep = val\n",
    "            emiDriver.loadSavedGraphToNewSession(modelPrefix, globalStep, redirFile=devnull)\n",
    "            predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
    "                                                                           minProb=0.99, keep_prob=1.0)\n",
    "\n",
    "            bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
    "            print(\"Round: %2d, Validation accuracy: %.4f\" % (round_, acc), end='')\n",
    "            print(', Test Accuracy (k = %d): %f, ' % (k,  np.mean((bagPredictions == BAG_TEST).astype(int))), end='')\n",
    "            print('Additional savings: %f' % getEarlySaving(predictionStep, NUM_TIMESTEPS)) \n",
    "        \n",
    "    \n",
    "    #-------------------------------END OF PICKING THE BEST MODEL--------------------------#\n",
    "\n",
    "    return lstm_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:43:51.798834Z",
     "start_time": "2019-06-19T05:43:51.792938Z"
    }
   },
   "outputs": [],
   "source": [
    "def experiment_generator(params, path, model = 'lstm'):\n",
    "    \"\"\"Dispatch to the experiment generator registered for `model`.\n",
    "\n",
    "    Recognised names are 'lstm', 'fastgrnn', 'gru' and 'baseline'; the matching\n",
    "    <model>_experiment_generator must be defined in the notebook before it is requested.\n",
    "    Raises ValueError for an unrecognised name (previously this silently returned None)\n",
    "    and NotImplementedError when the generator has not been defined.\n",
    "    \"\"\"\n",
    "    generatorNames = {\n",
    "        'lstm': 'lstm_experiment_generator',\n",
    "        'fastgrnn': 'fastgrnn_experiment_generator',\n",
    "        'gru': 'gru_experiment_generator',\n",
    "        'baseline': 'baseline_experiment_generator',\n",
    "    }\n",
    "    if model not in generatorNames:\n",
    "        raise ValueError('Unknown model %r; expected one of %s' % (model, sorted(generatorNames)))\n",
    "    # resolved lazily so models whose generator is not defined yet fail with a clear message\n",
    "    generator = globals().get(generatorNames[model])\n",
    "    if generator is None:\n",
    "        raise NotImplementedError('%s is not defined in this notebook' % generatorNames[model])\n",
    "    return generator(params, path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T05:43:53.168413Z",
     "start_time": "2019-06-19T05:43:53.164622Z"
    }
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-19T08:16:13.770690Z",
     "start_time": "2019-06-19T05:45:31.777404Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0811 06:07:09.419061 139879956309824 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1141: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n",
      "W0811 06:07:09.438090 139879956309824 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1153: The name tf.data.Iterator is deprecated. Please use tf.compat.v1.data.Iterator instead.\n",
      "\n",
      "W0811 06:07:09.439491 139879956309824 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1153: DatasetV1.output_types (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.data.get_output_types(dataset)`.\n",
      "W0811 06:07:09.440599 139879956309824 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1154: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.data.get_output_shapes(dataset)`.\n",
      "W0811 06:07:09.446861 139879956309824 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py:348: Iterator.output_types (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.data.get_output_types(iterator)`.\n",
      "W0811 06:07:09.448351 139879956309824 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py:349: Iterator.output_shapes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.data.get_output_shapes(iterator)`.\n",
      "W0811 06:07:09.449451 139879956309824 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py:351: Iterator.output_classes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.data.get_output_classes(iterator)`.\n",
      "W0811 06:07:09.453782 139879956309824 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1159: The name tf.add_to_collection is deprecated. Please use tf.compat.v1.add_to_collection instead.\n",
      "\n",
      "W0811 06:07:09.459037 139879956309824 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1396: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape is: (61735, 5, 64, 16)\n",
      "y_train shape is: (61735, 5, 5)\n",
      "x_test shape is: (6860, 5, 64, 16)\n",
      "y_test shape is: (6860, 5, 5)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0811 06:07:09.883782 139879956309824 lazy_loader.py:50] \n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "W0811 06:07:09.887611 139879956309824 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1404: static_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell, unroll=True)`, which is equivalent to this API\n",
      "W0811 06:07:10.070225 139879956309824 deprecation.py:506] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "W0811 06:07:10.082360 139879956309824 deprecation.py:506] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py:738: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "W0811 06:07:13.084146 139879956309824 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:135: The name tf.assign is deprecated. Please use tf.compat.v1.assign instead.\n",
      "\n",
      "W0811 06:07:13.117974 139879956309824 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:162: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
      "\n",
      "W0811 06:07:19.247013 139879956309824 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:329: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead.\n",
      "\n",
      "W0811 06:07:19.271834 139879956309824 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:369: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
      "\n",
      "W0811 06:07:19.286889 139879956309824 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:372: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Update policy: top-k\n",
      "Training with MI-RNN loss for 5 rounds\n",
      "Round: 0\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.02072 Acc 0.46875 | Val acc 0.40466 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1000\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01906 Acc 0.52500 | Val acc 0.45525 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1001\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01898 Acc 0.50000 | Val acc 0.48324 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1002\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01789 Acc 0.56875 | Val acc 0.52522 | "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0811 06:40:46.006091 139879956309824 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/utils.py:335: The name tf.train.import_meta_graph is deprecated. Please use tf.compat.v1.train.import_meta_graph instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved to DREAMER_AROUSAL/model-lstm, global_step 1003\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0811 06:40:48.303325 139879956309824 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Round: 1\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01708 Acc 0.61250 | Val acc 0.54213 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1004\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01477 Acc 0.68125 | Val acc 0.56735 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1005\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01506 Acc 0.66250 | Val acc 0.59446 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1006\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01383 Acc 0.68750 | Val acc 0.60233 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1007\n",
      "Round: 2\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01338 Acc 0.70000 | Val acc 0.61501 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1008\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01250 Acc 0.70625 | Val acc 0.62478 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1009\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01215 Acc 0.74375 | Val acc 0.63673 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1010\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01186 Acc 0.73750 | Val acc 0.64052 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1011\n",
      "Round: 3\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01130 Acc 0.78125 | Val acc 0.64519 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1012\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01183 Acc 0.74375 | Val acc 0.64942 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1013\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01149 Acc 0.74375 | Val acc 0.65204 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1014\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01207 Acc 0.75000 | Val acc 0.66137 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1015\n",
      "Round: 4\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01071 Acc 0.78750 | Val acc 0.65335 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1016\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01013 Acc 0.80625 | Val acc 0.66064 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1017\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01143 Acc 0.71250 | Val acc 0.66297 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1018\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.01101 Acc 0.79375 | Val acc 0.66676 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1019\n",
      "Round: 5\n",
      "Switching to EMI-Loss function\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.89245 Acc 0.76875 | Val acc 0.65306 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1020\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.84946 Acc 0.80625 | Val acc 0.66283 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1021\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.88004 Acc 0.74375 | Val acc 0.65831 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1022\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.84091 Acc 0.78750 | Val acc 0.66735 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1023\n",
      "Round: 6\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.85038 Acc 0.74375 | Val acc 0.66224 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1024\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.83785 Acc 0.72500 | Val acc 0.66647 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1025\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.81184 Acc 0.79375 | Val acc 0.67070 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1026\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.81252 Acc 0.73125 | Val acc 0.66968 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1027\n",
      "Round: 7\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.81934 Acc 0.78125 | Val acc 0.67478 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1028\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.82950 Acc 0.79375 | Val acc 0.67259 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1029\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.80113 Acc 0.76250 | Val acc 0.66808 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1030\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.80215 Acc 0.77500 | Val acc 0.66224 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1031\n",
      "Round: 8\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.80325 Acc 0.77500 | Val acc 0.66283 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1032\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.83236 Acc 0.75625 | Val acc 0.66691 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1033\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.76857 Acc 0.76875 | Val acc 0.66720 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1034\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.78972 Acc 0.77500 | Val acc 0.66837 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1035\n",
      "Round: 9\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.75360 Acc 0.79375 | Val acc 0.66924 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1036\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.77944 Acc 0.78125 | Val acc 0.67230 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1037\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.80270 Acc 0.75625 | Val acc 0.66866 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1038\n",
      "Epoch   1 Batch  1925 ( 3855) Loss 0.77377 Acc 0.76250 | Val acc 0.66181 | Model saved to DREAMER_AROUSAL/model-lstm, global_step 1039\n",
      "Accuracy at k = 2: 0.675725\n",
      "Savings due to MI-RNN : 0.500000\n",
      "Savings due to Early prediction: 0.222236\n",
      "Total Savings: 0.611118\n",
      "   len       acc  macro-fsc  macro-pre  macro-rec  micro-fsc  micro-pre  \\\n",
      "0    1  0.676249   0.645304   0.712381   0.620041   0.676249   0.676249   \n",
      "1    2  0.675725   0.643113   0.642737   0.646441   0.675725   0.675725   \n",
      "2    3  0.629658   0.579242   0.603195   0.630884   0.629658   0.629658   \n",
      "3    4  0.563240   0.541251   0.644912   0.593320   0.563240   0.563240   \n",
      "4    5  0.506385   0.508197   0.682197   0.555354   0.506385   0.506385   \n",
      "\n",
      "   micro-rec  \n",
      "0   0.676249  \n",
      "1   0.675725  \n",
      "2   0.629658  \n",
      "3   0.563240  \n",
      "4   0.506385  \n",
      "Max accuracy 0.676249 at subsequencelength 1\n",
      "Max micro-f 0.676249 at subsequencelength 1\n",
      "Micro-precision 0.676249 at subsequencelength 1\n",
      "Micro-recall 0.676249 at subsequencelength 1\n",
      "Max macro-f 0.645304 at subsequencelength 1\n",
      "macro-precision 0.712381 at subsequencelength 1\n",
      "macro-recall 0.620041 at subsequencelength 1\n",
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
      "|    |   len |      acc |   macro-fsc |   macro-pre |   macro-rec |   micro-fsc |   micro-pre |   micro-rec |\n",
      "+====+=======+==========+=============+=============+=============+=============+=============+=============+\n",
      "|  0 |     1 | 0.676249 |    0.645304 |    0.712381 |    0.620041 |    0.676249 |    0.676249 |    0.676249 |\n",
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
      "|  1 |     2 | 0.675725 |    0.643113 |    0.642737 |    0.646441 |    0.675725 |    0.675725 |    0.675725 |\n",
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
      "|  2 |     3 | 0.629658 |    0.579242 |    0.603195 |    0.630884 |    0.629658 |    0.629658 |    0.629658 |\n",
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
      "|  3 |     4 | 0.56324  |    0.541251 |    0.644912 |    0.59332  |    0.56324  |    0.56324  |    0.56324  |\n",
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
      "|  4 |     5 | 0.506385 |    0.508197 |    0.682197 |    0.555354 |    0.506385 |    0.506385 |    0.506385 |\n",
      "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'os' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-36-480e86a78fa5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     29\u001b[0m \u001b[0;31m#Preprocess data, and load the train,test and validation splits.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mlstm_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlstm_experiment_generator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     32\u001b[0m \u001b[0;31m#Create the directory to store the results of this run.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-33-f0d934a26c49>\u001b[0m in \u001b[0;36mlstm_experiment_generator\u001b[0;34m(params, path)\u001b[0m\n\u001b[1;32m    185\u001b[0m     \u001b[0;31m#----------------------------------PICKING THE BEST MODEL------------------------------#\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m     \u001b[0mdevnull\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevnull\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'r'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    188\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mval\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodelStats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m         \u001b[0mround_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodelPrefix\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mglobalStep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'os' is not defined"
     ]
    }
   ],
   "source": [
    "# Baseline EMI-LSTM\n",
    "import pathlib  # needed below for mkdir; not imported in the first (imports) cell\n",
    "\n",
    "dataset = 'DREAMER_AROUSAL'\n",
    "path = '/home/sf/data/DREAMER/Arousal'+ '/%d_%d/' % (subinstanceLen, subinstanceStride)\n",
    "\n",
    "#Choose model from among [lstm, fastgrnn, gru]\n",
    "model = 'lstm'\n",
    "\n",
    "# Dictionary to set the parameters.\n",
    "params = {\n",
    "    \"NUM_HIDDEN\" : 128,\n",
    "    \"NUM_TIMESTEPS\" : 64, #subinstance length.\n",
    "    \"ORIGINAL_NUM_TIMESTEPS\" : 128,\n",
    "    \"NUM_FEATS\" : 16,\n",
    "    \"FORGET_BIAS\" : 1.0,\n",
    "    \"NUM_OUTPUT\" : 5,\n",
    "    \"USE_DROPOUT\" : 1, # '1' -> True. '0' -> False\n",
    "    \"KEEP_PROB\" : 0.75,\n",
    "    \"PREFETCH_NUM\" : 5,\n",
    "    \"BATCH_SIZE\" : 32,\n",
    "    \"NUM_EPOCHS\" : 2,\n",
    "    \"NUM_ITER\" : 4,\n",
    "    \"NUM_ROUNDS\" : 10,\n",
    "    \"LEARNING_RATE\" : 0.001,\n",
    "    \"FRAC_EMI\" : 0.5,\n",
    "    \"MODEL_PREFIX\" : dataset + '/model-' + str(model)\n",
    "}\n",
    "\n",
    "# Preprocess data, train the chosen model and collect its results.\n",
    "# experiment_generator dispatches on `model`, so the choice above is actually honoured.\n",
    "lstm_dict = experiment_generator(params, path, model=model)\n",
    "\n",
    "# Create the directory (relative to the working directory) that stores the results of this run.\n",
    "dirname = model\n",
    "pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Timestamp that makes the results filename unique per run, e.g. '2019-8-11|8-16'.\n",
    "now = datetime.datetime.now()\n",
    "filename = '%d-%d-%d|%d-%d' % (now.year, now.month, now.day, now.hour, now.minute)\n",
    "\n",
    "# Save the dictionary containing the params and the results; `with` closes the file handle.\n",
    "with open(dirname + \"/\" + model + \"_dict_\" + filename + \".pkl\", mode='wb') as results_file:\n",
    "    pkl.dump(lstm_dict, results_file)\n",
    "print (\"Results for this run have been saved at\" , dirname, \".\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}