[d395cf]: / WESAD / WESAD_Inference.ipynb

Download this file

1074 lines (1073 with data), 51.8 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initializing and Restoring EMI-RNN Graphs\n",
    "\n",
    "The *EMI-RNN* implementation supports four forms of initialization/restoring:\n",
    "1. An entirely new graph can be constructed with randomly initialized weights.\n",
    "2. A saved graph can be loaded into the current `EMI_Driver`.\n",
    "3. An entirely new graph can be constructed with weights initialized from a saved graph. This behavior is essentially restoration of a saved graph.\n",
    "4. (*Experimental*) Initializing/Restoring using numpy matrices. \n",
    "\n",
    "These methods will be illustrated in this notebook. This notebook uses the WESAD dataset and builds on the [EMI LSTM example.ipynb](00_emi_lstm_example.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adapted from Microsoft's EdgeML notebooks (https://github.com/microsoft/EdgeML), authored by Dennis et al."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T11:47:35.576928Z",
     "start_time": "2018-08-19T11:47:34.670184Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import os\n",
    "import sys\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "# Expose only GPU 0 to TensorFlow (edgeml itself must already be importable from the Python path)\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] ='0'\n",
    "\n",
    "# MI-RNN and EMI-RNN imports\n",
    "from edgeml.graph.rnn import EMI_DataPipeline\n",
    "from edgeml.graph.rnn import EMI_BasicLSTM\n",
    "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
    "import edgeml.utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us set up some network parameters for the computation graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T11:47:35.590781Z",
     "start_time": "2018-08-19T11:47:35.578914Z"
    }
   },
   "outputs": [],
   "source": [
    "# Network parameters for our LSTM + FC Layer\n",
    "NUM_HIDDEN = 128\n",
    "NUM_TIMESTEPS = 64\n",
    "NUM_FEATS = 16\n",
    "FORGET_BIAS = 1.0\n",
    "NUM_OUTPUT = 5\n",
    "USE_DROPOUT = True\n",
    "KEEP_PROB = 0.75\n",
    "\n",
    "# For dataset API\n",
    "PREFETCH_NUM = 5\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "# Number of epochs in *one iteration*\n",
    "NUM_EPOCHS = 3\n",
    "# Number of iterations in *one round*. After each iteration,\n",
    "# the model is dumped to disk. At the end of the current\n",
    "# round, the best model among all the dumped models in the\n",
    "# current round is picked up.\n",
    "NUM_ITER = 4\n",
    "# A round consists of multiple training iterations and a belief\n",
    "# update step using the best model from all of these iterations\n",
    "NUM_ROUNDS = 10\n",
    "\n",
    "# A staging directory to store models\n",
    "dataset = 'DREAMER'\n",
    "MODEL_PREFIX = '/home/sf/data/EdgeML/tf/examples/EMI-RNN' + dataset + '/model-lstm'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network parameters for our LSTM + FC Layer\n",
    "NUM_HIDDEN = 128\n",
    "NUM_TIMESTEPS = 350\n",
    "NUM_FEATS = 8\n",
    "FORGET_BIAS = 1.0\n",
    "NUM_OUTPUT = 5\n",
    "USE_DROPOUT = True\n",
    "KEEP_PROB = 0.75\n",
    "\n",
    "# For dataset API\n",
    "PREFETCH_NUM = 5\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "# Number of epochs in *one iteration*\n",
    "NUM_EPOCHS = 3\n",
    "# Number of iterations in *one round*. After each iteration,\n",
    "# the model is dumped to disk. At the end of the current\n",
    "# round, the best model among all the dumped models in the\n",
    "# current round is picked up.\n",
    "NUM_ITER = 4\n",
    "# A round consists of multiple training iterations and a belief\n",
    "# update step using the best model from all of these iterations\n",
    "NUM_ROUNDS = 10\n",
    "\n",
    "# A staging directory to store models\n",
    "dataset = 'WESAD'\n",
    "MODEL_PREFIX = '/home/sf/data/EdgeML/tf/examples/EMI-RNN' + dataset + '/model-lstm'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T11:47:35.694831Z",
     "start_time": "2018-08-19T11:47:35.592516Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape is: (61735, 5, 64, 16)\n",
      "y_train shape is: (61735, 5, 5)\n",
      "x_test shape is: (6860, 5, 64, 16)\n",
      "y_test shape is: (6860, 5, 5)\n"
     ]
    }
   ],
   "source": [
    "path=\"/home/sf/data/DREAMER/Dominance/64_16/\"\n",
    "x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n",
    "x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n",
    "x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n",
    "\n",
    "# BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
    "# step of EMI/MI RNN\n",
    "BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
    "BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
    "BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
    "NUM_SUBINSTANCE = x_train.shape[1]\n",
    "print(\"x_train shape is:\", x_train.shape)\n",
    "print(\"y_train shape is:\", y_train.shape)\n",
    "print(\"x_test shape is:\", x_val.shape)\n",
    "print(\"y_test shape is:\", y_val.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_test shape is: (16863, 5, 350, 8)\n",
      "y_test shape is: (16863, 5, 5)\n"
     ]
    }
   ],
   "source": [
    "path=\"/home/sf/data/WESAD/350_116/\"\n",
    "#x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n",
    "x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n",
    "#x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n",
    "\n",
    "# BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
    "# step of EMI/MI RNN\n",
    "BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
    "#BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
    "#BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
    "NUM_SUBINSTANCE = x_test.shape[1]\n",
    "#print(\"x_train shape is:\", x_train.shape)\n",
    "#print(\"y_train shape is:\", y_train.shape)\n",
    "print(\"x_test shape is:\", x_test.shape)\n",
    "print(\"y_test shape is:\", y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T11:47:35.739003Z",
     "start_time": "2018-08-19T11:47:35.696723Z"
    }
   },
   "outputs": [],
   "source": [
    "# Define the linear secondary classifier\n",
    "def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
    "    W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
    "    B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
    "    y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
    "    self.output = y_cap\n",
    "    self.graphCreated = True\n",
    "    \n",
    "def addExtendedAssignOps(self, graph, W_val=None, B_val=None):\n",
    "    W1 = graph.get_tensor_by_name('W1:0')\n",
    "    B1 = graph.get_tensor_by_name('B1:0')\n",
    "    W1_op = tf.assign(W1, W_val)\n",
    "    B1_op = tf.assign(B1, B_val)\n",
    "    self.assignOps.extend([W1_op, B1_op])\n",
    "\n",
    "def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
    "    y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
    "    self.output = y_cap\n",
    "    self.graphCreated = True\n",
    "    \n",
    "def feedDictFunc(self, keep_prob, **kwargs):\n",
    "    feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
    "    return feedDict\n",
    "    \n",
    "EMI_BasicLSTM._createExtendedGraph = createExtendedGraph\n",
    "EMI_BasicLSTM._restoreExtendedGraph = restoreExtendedGraph\n",
    "EMI_BasicLSTM.addExtendedAssignOps = addExtendedAssignOps\n",
    "\n",
    "if USE_DROPOUT is True:\n",
    "    EMI_Driver.feedDictFunc = feedDictFunc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T09:34:06.288012Z",
     "start_time": "2018-08-19T09:34:06.285286Z"
    }
   },
   "source": [
    "## 1. Initializing a New Computation Graph\n",
    "\n",
    "In the most common use cases, a new EMI-RNN graph would be created and trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T11:49:55.326002Z",
     "start_time": "2018-08-19T11:49:50.568621Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0809 11:12:32.864859 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1141: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n",
      "W0809 11:12:32.881156 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1153: The name tf.data.Iterator is deprecated. Please use tf.compat.v1.data.Iterator instead.\n",
      "\n",
      "W0809 11:12:32.882301 140177766897472 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1153: DatasetV1.output_types (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.data.get_output_types(dataset)`.\n",
      "W0809 11:12:32.883116 140177766897472 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1154: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.data.get_output_shapes(dataset)`.\n",
      "W0809 11:12:32.887950 140177766897472 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py:348: Iterator.output_types (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.data.get_output_types(iterator)`.\n",
      "W0809 11:12:32.888927 140177766897472 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py:349: Iterator.output_shapes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.data.get_output_shapes(iterator)`.\n",
      "W0809 11:12:32.889680 140177766897472 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py:351: Iterator.output_classes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.data.get_output_classes(iterator)`.\n",
      "W0809 11:12:32.892827 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1159: The name tf.add_to_collection is deprecated. Please use tf.compat.v1.add_to_collection instead.\n",
      "\n",
      "W0809 11:12:32.899183 140177766897472 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1396: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n",
      "W0809 11:12:33.587531 140177766897472 lazy_loader.py:50] \n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "W0809 11:12:33.592689 140177766897472 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1404: static_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell, unroll=True)`, which is equivalent to this API\n",
      "W0809 11:12:33.634694 140177766897472 deprecation.py:506] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "W0809 11:12:33.646812 140177766897472 deprecation.py:506] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py:738: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "W0809 11:12:49.003790 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:135: The name tf.assign is deprecated. Please use tf.compat.v1.assign instead.\n",
      "\n",
      "W0809 11:12:49.038759 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:162: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
      "\n",
      "W0809 11:13:19.933787 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:329: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
    "emiLSTM = EMI_BasicLSTM(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS,\n",
    "                        forgetBias=FORGET_BIAS, useDropout=USE_DROPOUT)\n",
    "emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy')\n",
    "\n",
    "# Construct the graph\n",
    "g1 = tf.Graph()    \n",
    "with g1.as_default():\n",
    "    x_batch, y_batch = inputPipeline()\n",
    "    y_cap = emiLSTM(x_batch)\n",
    "    emiTrainer(y_cap, y_batch)\n",
    "    \n",
    "    \n",
    "with g1.as_default():\n",
    "    emiDriver = EMI_Driver(inputPipeline, emiLSTM, emiTrainer)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's initialize a new session with this graph and train a model. The saved model will be used later for restoring."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T11:48:32.894784Z",
     "start_time": "2018-08-19T11:47:41.258027Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0809 11:13:19.965286 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:369: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
      "\n",
      "W0809 11:13:19.982229 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:372: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "emiDriver.initializeSession(g1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the output above indicates, the last restored model is `/tmp/model-lstm-1001`. That is, with `MODEL_PREFIX = '/tmp/model-lstm'` and `GLOBAL_STEP=1001`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T11:48:33.294431Z",
     "start_time": "2018-08-19T11:48:32.897376Z"
    }
   },
   "outputs": [],
   "source": [
    "def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
    "    assert instanceOut.ndim == 2\n",
    "    classes = np.argmax(instanceOut, axis=1)\n",
    "    prob = np.max(instanceOut, axis=1)\n",
    "    index = np.where(prob >= minProb)[0]\n",
    "    if len(index) == 0:\n",
    "        assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
    "        return classes[-1], len(instanceOut) - 1\n",
    "    index = index[0]\n",
    "    return classes[index], index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T11:48:33.294431Z",
     "start_time": "2018-08-19T11:48:32.897376Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4min 49s, sys: 23.3 s, total: 5min 13s\n",
      "Wall time: 1min 1s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "k = 1\n",
    "predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
    "                                                               minProb=0.99, keep_prob=1.0)\n",
    "bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T11:48:33.294431Z",
     "start_time": "2018-08-19T11:48:32.897376Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy at k = 1: 0.087825\n"
     ]
    }
   ],
   "source": [
    "print('Accuracy at k = %d: %f' % (k,  np.mean((bagPredictions == BAG_TEST).astype(int))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "numpy.ndarray"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(x_test[:64])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Loading a Saved Graph into EMI-Driver\n",
    "\n",
    "We will reset the computation graph, load a saved graph into the current `EMI_Driver` and verify its outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_PREFIX='/home/sf/data/EdgeML/tf/examples/EMI-RNN//DREAMER/model-lstm'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_PREFIX='/home/sf/data/EdgeML/tf/examples/EMI-RNN//WESAD/model-lstm'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T11:48:38.811810Z",
     "start_time": "2018-08-19T11:48:33.296990Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0809 11:18:00.678148 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/utils.py:335: The name tf.train.import_meta_graph is deprecated. Please use tf.compat.v1.train.import_meta_graph instead.\n",
      "\n",
      "W0809 11:18:12.201538 140177766897472 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "emiDriver.loadSavedGraphToNewSession(MODEL_PREFIX, 1007)\n",
    "k = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(16863, 5, 350, 8)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "k=1\n",
    "x_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[[ 8.01725443e-01, -3.94587280e-02,  4.08067890e-01, ...,\n",
       "          -1.03118241e+00, -3.86102984e-02, -8.90117132e-01],\n",
       "         [ 7.85085095e-01, -3.75467141e-02,  4.11676023e-01, ...,\n",
       "          -1.02684695e+00, -4.58810468e-02, -8.87501712e-01],\n",
       "         [ 8.09289566e-01, -6.43183283e-02,  4.21297533e-01, ...,\n",
       "          -1.03204951e+00, -2.27468473e-02, -8.93479815e-01],\n",
       "         ...,\n",
       "         [ 7.51804399e-01,  2.26345970e-01,  3.34102320e-01, ...,\n",
       "          -1.03172435e+00, -3.73617860e-02, -2.82218773e-01],\n",
       "         [ 7.45753281e-01,  2.26345970e-01,  3.29892920e-01, ...,\n",
       "          -1.02847275e+00, -3.98098495e-02, -2.77735196e-01],\n",
       "         [ 7.51804399e-01,  2.41644361e-01,  3.28088943e-01, ...,\n",
       "          -1.03150757e+00, -7.14877903e-02, -2.72130724e-01]],\n",
       "\n",
       "        [[ 7.51804399e-01,  8.48387035e-02,  3.70784382e-01, ...,\n",
       "          -1.03204951e+00, -3.24901398e-02, -8.54995776e-01],\n",
       "         [ 7.41215168e-01,  9.43999129e-02,  3.72588538e-01, ...,\n",
       "          -1.02847275e+00, -4.58810468e-02, -8.53501250e-01],\n",
       "         [ 7.36677055e-01,  1.00136525e-01,  3.65372272e-01, ...,\n",
       "          -1.02619663e+00, -3.49382032e-02, -8.52753987e-01],\n",
       "         ...,\n",
       "         [ 7.21549712e-01,  2.37819763e-01,  3.67777694e-01, ...,\n",
       "          -1.03107403e+00, -4.71295591e-02,  8.73027249e-02],\n",
       "         [ 7.17011599e-01,  2.37819763e-01,  3.57554918e-01, ...,\n",
       "          -1.02728049e+00, -3.98098495e-02,  8.99181451e-02],\n",
       "         [ 7.30625937e-01,  2.30170568e-01,  3.29291655e-01, ...,\n",
       "          -1.03064048e+00, -8.12066021e-02,  9.62698796e-02]],\n",
       "\n",
       "        [[ 8.21390900e-01,  1.09697734e-01,  4.51364774e-01, ...,\n",
       "          -1.02890629e+00, -3.73617860e-02, -6.38289539e-01],\n",
       "         [ 8.25929013e-01,  1.00136525e-01,  4.60986283e-01, ...,\n",
       "          -1.02793081e+00, -7.38868924e-02, -6.30069647e-01],\n",
       "         [ 8.18365792e-01,  1.03961122e-01,  4.86242859e-01, ...,\n",
       "          -1.03107403e+00, -3.24901398e-02, -6.33432330e-01],\n",
       "         ...,\n",
       "         [ 7.29112932e-01,  2.30170568e-01,  3.09447190e-01, ...,\n",
       "          -1.03020693e+00, -4.71295591e-02,  4.11241187e-01],\n",
       "         [ 7.30625937e-01,  2.22521372e-01,  3.02231102e-01, ...,\n",
       "          -1.02522115e+00, -3.24901398e-02,  4.13856607e-01],\n",
       "         [ 7.30625937e-01,  2.03398953e-01,  3.01028391e-01, ...,\n",
       "          -1.03020693e+00, -4.22334323e-02,  4.23571024e-01]],\n",
       "\n",
       "        [[ 7.45753281e-01,  2.26345970e-01,  3.29892920e-01, ...,\n",
       "          -1.02847275e+00, -3.98098495e-02, -2.77735196e-01],\n",
       "         [ 7.51804399e-01,  2.41644361e-01,  3.28088943e-01, ...,\n",
       "          -1.03150757e+00, -7.14877903e-02, -2.72130724e-01],\n",
       "         [ 7.69957752e-01,  2.33995165e-01,  3.31095632e-01, ...,\n",
       "          -1.02933984e+00, -3.24901398e-02, -2.67647147e-01],\n",
       "         ...,\n",
       "         [ 7.00371250e-01,  1.61328948e-01,  2.48109997e-01, ...,\n",
       "          -1.02955661e+00, -2.88425253e-02,  7.00805555e-01],\n",
       "         [ 6.94320133e-01,  1.46031127e-01,  2.57731507e-01, ...,\n",
       "          -1.02663017e+00, -3.86102984e-02,  7.03794607e-01],\n",
       "         [ 6.97345241e-01,  1.34557334e-01,  2.66150306e-01, ...,\n",
       "          -1.02944823e+00, -3.49382032e-02,  7.15377182e-01]],\n",
       "\n",
       "        [[ 7.17011599e-01,  2.37819763e-01,  3.57554918e-01, ...,\n",
       "          -1.02728049e+00, -3.98098495e-02,  8.99181451e-02],\n",
       "         [ 7.30625937e-01,  2.30170568e-01,  3.29291655e-01, ...,\n",
       "          -1.03064048e+00, -8.12066021e-02,  9.62698796e-02],\n",
       "         [ 7.30625937e-01,  2.35907749e-01,  3.34703765e-01, ...,\n",
       "          -1.02749727e+00, -5.68483709e-02,  9.62698796e-02],\n",
       "         ...,\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]],\n",
       "\n",
       "\n",
       "       [[[ 2.25366173e-01, -6.37998299e-01,  2.01606543e+00, ...,\n",
       "           1.00692097e+00,  2.54545298e-01, -5.98684606e-01],\n",
       "         [ 2.35955404e-01, -6.37998299e-01,  1.99441699e+00, ...,\n",
       "           1.00713774e+00,  2.33761239e-01, -5.96442817e-01],\n",
       "         [ 2.34442399e-01, -6.24612492e-01,  1.98239024e+00, ...,\n",
       "           1.00724613e+00,  2.14152251e-01, -6.00179131e-01],\n",
       "         ...,\n",
       "         [ 5.47583007e-01,  1.06774322e+00,  1.37142302e+00, ...,\n",
       "           1.04886662e+00,  2.38657366e-01, -1.62542381e+00],\n",
       "         [ 5.35480772e-01,  1.06774322e+00,  1.37863929e+00, ...,\n",
       "           1.04973371e+00,  2.43529012e-01, -1.62542381e+00],\n",
       "         [ 5.40018885e-01,  1.09834000e+00,  1.38465266e+00, ...,\n",
       "           1.04713243e+00,  2.41105429e-01, -1.62579744e+00]],\n",
       "\n",
       "        [[ 3.53949945e-01, -1.29335350e-01,  1.86693176e+00, ...,\n",
       "           1.01179837e+00,  2.58241873e-01, -9.68953366e-01],\n",
       "         [ 3.55462950e-01, -1.56106964e-01,  1.86392489e+00, ...,\n",
       "           1.01223192e+00,  2.54545298e-01, -9.71568786e-01],\n",
       "         [ 3.49411832e-01, -1.52282366e-01,  1.85911441e+00, ...,\n",
       "           1.01158160e+00,  1.94592224e-01, -9.73063312e-01],\n",
       "         ...,\n",
       "         [ 5.40018885e-01,  1.45784535e+00,  1.32091023e+00, ...,\n",
       "           1.06664204e+00,  2.37408854e-01, -1.66876506e+00],\n",
       "         [ 5.43044894e-01,  1.46549455e+00,  1.32692360e+00, ...,\n",
       "           1.06664204e+00,  2.41105429e-01, -1.67025959e+00],\n",
       "         [ 5.18840424e-01,  1.44445954e+00,  1.33774782e+00, ...,\n",
       "           1.06707559e+00,  2.09280605e-01, -1.66951232e+00]],\n",
       "\n",
       "        [[ 2.93439669e-01,  3.44906790e-01,  1.50492198e+00, ...,\n",
       "           1.02458800e+00,  2.50897683e-01, -1.34183755e+00],\n",
       "         [ 2.97978684e-01,  3.39169608e-01,  1.49409776e+00, ...,\n",
       "           1.02491316e+00,  2.54545298e-01, -1.35304649e+00],\n",
       "         [ 3.01003792e-01,  3.60204611e-01,  1.49048945e+00, ...,\n",
       "           1.02480478e+00,  2.53345746e-01, -1.34968381e+00],\n",
       "         ...,\n",
       "         [ 1.01320514e-01,  1.67775602e+00,  1.18620873e+00, ...,\n",
       "           1.09298001e+00,  1.93392673e-01, -1.45131156e+00],\n",
       "         [ 9.52693960e-02,  1.64142376e+00,  1.18620873e+00, ...,\n",
       "           1.09395549e+00,  2.43529012e-01, -1.44794888e+00],\n",
       "         [ 8.77052734e-02,  1.62994940e+00,  1.18380331e+00, ...,\n",
       "           1.09438903e+00,  2.23944505e-01, -1.44832251e+00]],\n",
       "\n",
       "        [[ 5.35480772e-01,  1.06774322e+00,  1.37863929e+00, ...,\n",
       "           1.04973371e+00,  2.43529012e-01, -1.62542381e+00],\n",
       "         [ 5.40018885e-01,  1.09834000e+00,  1.38465266e+00, ...,\n",
       "           1.04713243e+00,  2.41105429e-01, -1.62579744e+00],\n",
       "         [ 5.43044894e-01,  1.12128645e+00,  1.39186893e+00, ...,\n",
       "           1.04962533e+00,  2.38657366e-01, -1.62841286e+00],\n",
       "         ...,\n",
       "         [ 4.82534619e-01,  2.75818692e+00,  1.23491790e+00, ...,\n",
       "           1.12538763e+00,  2.37408854e-01, -1.04554781e+00],\n",
       "         [ 4.55305040e-01,  2.76009951e+00,  1.24393779e+00, ...,\n",
       "           1.12625472e+00,  2.43529012e-01, -1.04031697e+00],\n",
       "         [ 4.22024343e-01,  2.74288853e+00,  1.25115406e+00, ...,\n",
       "           1.12603795e+00,  2.43529012e-01, -1.03583339e+00]],\n",
       "\n",
       "        [[ 5.43044894e-01,  1.46549455e+00,  1.32692360e+00, ...,\n",
       "           1.06664204e+00,  2.41105429e-01, -1.67025959e+00],\n",
       "         [ 5.18840424e-01,  1.44445954e+00,  1.33774782e+00, ...,\n",
       "           1.06707559e+00,  2.09280605e-01, -1.66951232e+00],\n",
       "         [ 4.99174967e-01,  1.45402132e+00,  1.33173445e+00, ...,\n",
       "           1.06696720e+00,  1.30991536e-01, -1.66540238e+00],\n",
       "         ...,\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]],\n",
       "\n",
       "\n",
       "       [[[ 2.97978684e-01,  3.44906790e-01, -3.25573536e-01, ...,\n",
       "           1.89200671e+00,  3.88185080e-01, -2.41866577e-01],\n",
       "         [ 3.01003792e-01,  3.52555985e-01, -3.30985646e-01, ...,\n",
       "           1.89103123e+00,  3.95529271e-01, -2.41492945e-01],\n",
       "         [ 3.07054910e-01,  3.35345580e-01, -3.38803179e-01, ...,\n",
       "           1.89449960e+00,  4.07843030e-01, -2.44481997e-01],\n",
       "         ...,\n",
       "         [ 3.61514068e-01,  3.12397994e-01, -2.44993144e-01, ...,\n",
       "           1.89850991e+00,  4.17659764e-01, -4.59693709e-01],\n",
       "         [ 3.78154416e-01,  3.41082192e-01, -2.14324638e-01, ...,\n",
       "           1.89731766e+00,  4.15187220e-01, -4.96309590e-01],\n",
       "         [ 3.79667421e-01,  3.60204611e-01, -2.04703128e-01, ...,\n",
       "           1.89731766e+00,  4.15187220e-01, -4.59693709e-01]],\n",
       "\n",
       "        [[ 2.94952674e-01,  3.67853806e-01, -3.43613844e-01, ...,\n",
       "           1.89092284e+00,  4.10291093e-01, -3.93934574e-01],\n",
       "         [ 2.97978684e-01,  3.60204611e-01, -3.43012578e-01, ...,\n",
       "           1.89124800e+00,  3.61207421e-01, -3.93187312e-01],\n",
       "         [ 2.97978684e-01,  3.56380013e-01, -3.42411133e-01, ...,\n",
       "           1.89113962e+00,  4.17659764e-01, -4.00659940e-01],\n",
       "         ...,\n",
       "         [ 2.67723095e-01,  3.79327599e-01, -3.41208422e-01, ...,\n",
       "           1.89341574e+00,  4.17659764e-01, -4.26066879e-01],\n",
       "         [ 2.66210091e-01,  3.67853806e-01, -3.37600468e-01, ...,\n",
       "           1.89341574e+00,  4.07843030e-01, -4.23451459e-01],\n",
       "         [ 2.58646870e-01,  3.67853806e-01, -3.32789623e-01, ...,\n",
       "           1.89384928e+00,  4.12739156e-01, -4.23077827e-01]],\n",
       "\n",
       "        [[ 2.82850439e-01,  3.37257594e-01, -3.54438065e-01, ...,\n",
       "           1.88918866e+00,  4.21331859e-01, -4.62682760e-01],\n",
       "         [ 2.97978684e-01,  3.23871787e-01, -3.36397757e-01, ...,\n",
       "           1.88951382e+00,  4.20107827e-01, -4.58946446e-01],\n",
       "         [ 2.97978684e-01,  3.41082192e-01, -3.40607156e-01, ...,\n",
       "           1.89027252e+00,  3.42797984e-01, -4.53715605e-01],\n",
       "         ...,\n",
       "         [ 3.13106027e-01,  3.75503002e-01, -2.82276652e-01, ...,\n",
       "           1.89428283e+00,  4.22580371e-01, -3.69274899e-01],\n",
       "         [ 3.05541905e-01,  3.71678404e-01, -3.02722382e-01, ...,\n",
       "           1.89471638e+00,  4.26252467e-01, -3.61055007e-01],\n",
       "         [ 3.10080018e-01,  3.52555985e-01, -3.26174802e-01, ...,\n",
       "           1.89558347e+00,  4.20107827e-01, -4.29055930e-01]],\n",
       "\n",
       "        [[ 3.78154416e-01,  3.41082192e-01, -2.14324638e-01, ...,\n",
       "           1.89731766e+00,  4.15187220e-01, -4.96309590e-01],\n",
       "         [ 3.79667421e-01,  3.60204611e-01, -2.04703128e-01, ...,\n",
       "           1.89731766e+00,  4.15187220e-01, -4.59693709e-01],\n",
       "         [ 3.73616303e-01,  3.75503002e-01, -1.84257397e-01, ...,\n",
       "           1.89710088e+00,  4.21331859e-01, -4.56704657e-01],\n",
       "         ...,\n",
       "         [ 3.10080018e-01,  3.48731388e-01, -3.35796311e-01, ...,\n",
       "           1.89178994e+00,  4.15187220e-01, -3.42373435e-01],\n",
       "         [ 3.01003792e-01,  3.35345580e-01, -3.37600468e-01, ...,\n",
       "           1.89200671e+00,  4.11515125e-01, -3.44241592e-01],\n",
       "         [ 3.01003792e-01,  3.46818804e-01, -3.30384201e-01, ...,\n",
       "           1.89254864e+00,  3.62431453e-01, -3.44615224e-01]],\n",
       "\n",
       "        [[ 2.66210091e-01,  3.67853806e-01, -3.37600468e-01, ...,\n",
       "           1.89341574e+00,  4.07843030e-01, -4.23451459e-01],\n",
       "         [ 2.58646870e-01,  3.67853806e-01, -3.32789623e-01, ...,\n",
       "           1.89384928e+00,  4.12739156e-01, -4.23077827e-01],\n",
       "         [ 2.69236100e-01,  3.46818804e-01, -3.29181490e-01, ...,\n",
       "           1.89309058e+00,  4.12739156e-01, -5.36661786e-01],\n",
       "         ...,\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]],\n",
       "\n",
       "\n",
       "       ...,\n",
       "\n",
       "\n",
       "       [[[-1.34335927e+00,  2.22521372e-01, -1.34004308e+00, ...,\n",
       "           1.46420443e+00,  1.65215463e-01,  2.04622998e-01],\n",
       "         [-1.34941039e+00,  2.07223551e-01, -1.32741488e+00, ...,\n",
       "           1.46463798e+00,  1.75032197e-01,  2.05370261e-01],\n",
       "         [-1.34487228e+00,  1.91925160e-01, -1.35026604e+00, ...,\n",
       "           1.46420443e+00,  1.62816361e-01,  2.04622998e-01],\n",
       "         ...,\n",
       "         [-1.34033417e+00,  1.84276535e-01, -1.35928610e+00, ...,\n",
       "           1.45575027e+00,  1.75032197e-01,  4.22823761e-01],\n",
       "         [-1.34033417e+00,  2.07223551e-01, -1.34725935e+00, ...,\n",
       "           1.45553350e+00,  1.75032197e-01,  4.22823761e-01],\n",
       "         [-1.33882116e+00,  2.16784761e-01, -1.35267146e+00, ...,\n",
       "           1.45575027e+00,  1.72584134e-01,  4.30670022e-01]],\n",
       "\n",
       "        [[-1.33277004e+00,  2.20609358e-01, -1.34425248e+00, ...,\n",
       "           1.46095283e+00,  1.26119890e-01,  2.51326929e-01],\n",
       "         [-1.33277004e+00,  1.88100563e-01, -1.36349568e+00, ...,\n",
       "           1.46095283e+00,  1.46928429e-01,  2.51326929e-01],\n",
       "         [-1.33579605e+00,  2.16784761e-01, -1.35447544e+00, ...,\n",
       "           1.46073606e+00,  1.59144266e-01,  2.51700560e-01],\n",
       "         ...,\n",
       "         [-1.35999962e+00,  1.95749758e-01, -1.35146875e+00, ...,\n",
       "           1.45444963e+00,  1.62816361e-01,  5.75265390e-01],\n",
       "         [-1.36453773e+00,  1.91925160e-01, -1.35146875e+00, ...,\n",
       "           1.45509995e+00,  1.45679917e-01,  5.81243493e-01],\n",
       "         [-1.35092340e+00,  1.80451937e-01, -1.34786061e+00, ...,\n",
       "           1.45531672e+00,  1.65215463e-01,  5.87221597e-01]],\n",
       "\n",
       "        [[-1.35697451e+00,  2.22521372e-01, -1.33523242e+00, ...,\n",
       "           1.45835155e+00,  1.79903844e-01,  3.14470643e-01],\n",
       "         [-1.34789739e+00,  2.26345970e-01, -1.35447544e+00, ...,\n",
       "           1.45835155e+00,  1.72584134e-01,  3.21196009e-01],\n",
       "         [-1.34033417e+00,  2.01486940e-01, -1.35928610e+00, ...,\n",
       "           1.45845994e+00,  1.66463975e-01,  3.24185060e-01],\n",
       "         ...,\n",
       "         [-1.29646424e+00,  1.86188549e-01, -1.32260404e+00, ...,\n",
       "           1.46073606e+00,  1.70136070e-01,  7.59465692e-01],\n",
       "         [-1.29948935e+00,  1.65153546e-01, -1.33823910e+00, ...,\n",
       "           1.46116960e+00,  1.55472170e-01,  7.63202007e-01],\n",
       "         [-1.31612969e+00,  1.91925160e-01, -1.34425248e+00, ...,\n",
       "           1.46095283e+00,  1.55472170e-01,  7.61707481e-01]],\n",
       "\n",
       "        [[-1.34033417e+00,  2.07223551e-01, -1.34725935e+00, ...,\n",
       "           1.45553350e+00,  1.75032197e-01,  4.22823761e-01],\n",
       "         [-1.33882116e+00,  2.16784761e-01, -1.35267146e+00, ...,\n",
       "           1.45575027e+00,  1.72584134e-01,  4.30670022e-01],\n",
       "         [-1.35394850e+00,  2.03398953e-01, -1.34605664e+00, ...,\n",
       "           1.45596704e+00,  1.45679917e-01,  4.21702867e-01],\n",
       "         ...,\n",
       "         [-1.39176731e+00,  2.01486940e-01, -1.34725935e+00, ...,\n",
       "           1.45444963e+00,  1.75032197e-01,  9.48896834e-01],\n",
       "         [-1.39328032e+00,  2.01486940e-01, -1.34485393e+00, ...,\n",
       "           1.45455802e+00,  1.72584134e-01,  9.60853040e-01],\n",
       "         [-1.40084444e+00,  2.03398953e-01, -1.34064453e+00, ...,\n",
       "           1.45455802e+00,  1.67688007e-01,  9.58984883e-01]],\n",
       "\n",
       "        [[-1.36453773e+00,  1.91925160e-01, -1.35146875e+00, ...,\n",
       "           1.45509995e+00,  1.45679917e-01,  5.81243493e-01],\n",
       "         [-1.35092340e+00,  1.80451937e-01, -1.34786061e+00, ...,\n",
       "           1.45531672e+00,  1.65215463e-01,  5.87221597e-01],\n",
       "         [-1.35697451e+00,  1.72802742e-01, -1.35146875e+00, ...,\n",
       "           1.45509995e+00,  1.26119890e-01,  5.89463385e-01],\n",
       "         ...,\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]],\n",
       "\n",
       "\n",
       "       [[[ 7.51804399e-01,  8.34446750e-01,  9.02373921e-01, ...,\n",
       "          -5.44851320e-01,  3.19516901e-01,  8.35664105e-02],\n",
       "         [ 7.44241178e-01,  8.38271918e-01,  9.11995252e-01, ...,\n",
       "          -5.45393254e-01,  3.04779559e-01,  9.10390394e-02],\n",
       "         [ 7.51804399e-01,  8.53569169e-01,  9.11995252e-01, ...,\n",
       "          -5.43767453e-01,  3.07252103e-01,  6.67529955e-02],\n",
       "         ...,\n",
       "         [ 6.80704892e-01,  1.02567322e+00,  9.45670805e-01, ...,\n",
       "          -5.44634547e-01,  3.15820325e-01, -2.07866116e-01],\n",
       "         [ 6.80704892e-01,  1.02949839e+00,  9.45670805e-01, ...,\n",
       "          -5.43767453e-01,  3.29284674e-01, -1.49579610e-01],\n",
       "         [ 6.76166779e-01,  9.96989019e-01,  9.27630317e-01, ...,\n",
       "          -5.75416368e-01,  3.29284674e-01, -1.47337822e-01]],\n",
       "\n",
       "        [[ 7.42728173e-01,  9.14761593e-01,  9.49278759e-01, ...,\n",
       "          -5.37697798e-01,  2.47201107e-01, -4.27210174e-02],\n",
       "         [ 7.45753281e-01,  9.10937565e-01,  9.34245139e-01, ...,\n",
       "          -5.46043574e-01,  3.09675686e-01,  1.29500677e-02],\n",
       "         [ 7.33651045e-01,  9.07113537e-01,  9.42062493e-01, ...,\n",
       "          -5.44851320e-01,  3.09675686e-01,  1.14555419e-02],\n",
       "         ...,\n",
       "         [ 7.01883354e-01,  1.03714758e+00,  9.80548532e-01, ...,\n",
       "          -5.41708106e-01,  3.32981250e-01, -1.70876603e-01],\n",
       "         [ 7.06421466e-01,  1.02949839e+00,  9.72129733e-01, ...,\n",
       "          -5.40407466e-01,  3.21964964e-01, -1.70129340e-01],\n",
       "         [ 7.12472584e-01,  1.04097161e+00,  9.73933710e-01, ...,\n",
       "          -5.33145557e-01,  3.02355976e-01, -1.70129340e-01]],\n",
       "\n",
       "        [[ 7.36677055e-01,  9.95077575e-01,  9.51684182e-01, ...,\n",
       "          -5.42900360e-01,  3.24413028e-01, -7.33587958e-02],\n",
       "         [ 7.27600829e-01,  9.68305961e-01,  9.58299003e-01, ...,\n",
       "          -5.44092613e-01,  3.98001815e-01, -7.44796902e-02],\n",
       "         [ 7.10960481e-01,  9.72129989e-01,  9.61305512e-01, ...,\n",
       "          -5.41816493e-01,  3.31757218e-01, -7.41060587e-02],\n",
       "         ...,\n",
       "         [ 7.36677055e-01,  1.00272677e+00,  9.61906957e-01, ...,\n",
       "          -5.44742933e-01,  3.26861091e-01, -1.64524868e-01],\n",
       "         [ 7.35164050e-01,  9.98901603e-01,  9.54690691e-01, ...,\n",
       "          -5.43333906e-01,  2.98659401e-01, -1.61535817e-01],\n",
       "         [ 7.36677055e-01,  9.98901603e-01,  9.60704426e-01, ...,\n",
       "          -5.42683586e-01,  3.23164516e-01, -1.60788554e-01]],\n",
       "\n",
       "        [[ 6.80704892e-01,  1.02949839e+00,  9.45670805e-01, ...,\n",
       "          -5.43767453e-01,  3.29284674e-01, -1.49579610e-01],\n",
       "         [ 6.76166779e-01,  9.96989019e-01,  9.27630317e-01, ...,\n",
       "          -5.75416368e-01,  3.29284674e-01, -1.47337822e-01],\n",
       "         [ 6.82217897e-01,  1.01037597e+00,  9.37251648e-01, ...,\n",
       "          -5.43442293e-01,  2.85170571e-01, -1.49579610e-01],\n",
       "         ...,\n",
       "         [ 7.03396358e-01,  1.00655080e+00,  9.77542023e-01, ...,\n",
       "          -5.40299079e-01,  3.25637060e-01, -2.00019855e-01],\n",
       "         [ 7.26087824e-01,  9.96989019e-01,  9.84156845e-01, ...,\n",
       "          -5.45826800e-01,  3.17068838e-01, -1.97778067e-01],\n",
       "         [ 7.21549712e-01,  9.98901603e-01,  9.85359377e-01, ...,\n",
       "          -5.46043574e-01,  3.32981250e-01, -1.99272592e-01]],\n",
       "\n",
       "        [[ 7.06421466e-01,  1.02949839e+00,  9.72129733e-01, ...,\n",
       "          -5.40407466e-01,  3.21964964e-01, -1.70129340e-01],\n",
       "         [ 7.12472584e-01,  1.04097161e+00,  9.73933710e-01, ...,\n",
       "          -5.33145557e-01,  3.02355976e-01, -1.70129340e-01],\n",
       "         [ 7.12472584e-01,  1.04097161e+00,  9.56495026e-01, ...,\n",
       "          -5.44959707e-01,  3.13372262e-01, -1.65272131e-01],\n",
       "         ...,\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]],\n",
       "\n",
       "\n",
       "       [[[ 1.04376564e+00,  1.32398728e+00, -4.93348782e-01, ...,\n",
       "           1.24201171e+00, -2.46573287e-01, -1.40572852e+00],\n",
       "         [ 1.14360773e+00,  1.28765389e+00, -4.77112450e-01, ...,\n",
       "           1.24331235e+00, -2.35654924e-01, -1.41544294e+00],\n",
       "         [ 1.24496193e+00,  1.23984727e+00, -4.57869430e-01, ...,\n",
       "           1.24331235e+00, -2.32007309e-01, -1.41768473e+00],\n",
       "         ...,\n",
       "         [ 3.43360714e-01, -9.86030993e-01,  1.62117675e-01, ...,\n",
       "           1.21870857e+00, -2.28384176e-01, -1.19836307e+00],\n",
       "         [ 3.43360714e-01, -9.38224377e-01,  1.65124363e-01, ...,\n",
       "           1.22033437e+00, -2.51444933e-01, -1.19201134e+00],\n",
       "         [ 3.49411832e-01, -9.13364776e-01,  1.59110807e-01, ...,\n",
       "           1.22488661e+00, -2.34430892e-01, -1.18453871e+00]],\n",
       "\n",
       "        [[-6.61107696e-01, -6.99190723e-01,  3.32298343e-01, ...,\n",
       "           1.23073949e+00, -2.17441332e-01, -1.86716336e+00],\n",
       "         [-7.06490629e-01, -7.46997340e-01,  3.34703765e-01, ...,\n",
       "           1.23073949e+00, -2.19840434e-01, -1.87127330e+00],\n",
       "         [-7.18591962e-01, -7.98628554e-01,  3.22676833e-01, ...,\n",
       "           1.23052272e+00, -2.22288498e-01, -1.87127330e+00],\n",
       "         ...,\n",
       "         [ 4.34126579e-01, -3.70281017e-01,  3.55625774e-04, ...,\n",
       "           1.23355755e+00, -2.35654924e-01, -5.60574198e-01],\n",
       "         [ 5.06738188e-01, -3.51158029e-01,  1.47879802e-02, ...,\n",
       "           1.23355755e+00, -2.23512529e-01, -5.60574198e-01],\n",
       "         [ 5.73299581e-01, -3.24386415e-01,  2.08013566e-02, ...,\n",
       "           1.23377432e+00, -2.18640883e-01, -5.54596095e-01]],\n",
       "\n",
       "        [[ 4.52279030e-01, -7.38795377e-02,  3.88824870e-01, ...,\n",
       "           1.21968405e+00, -2.23512529e-01, -1.82793206e+00],\n",
       "         [ 4.12947216e-01, -9.87391379e-02,  3.99649091e-01, ...,\n",
       "           1.21827502e+00, -2.27135663e-01, -1.82158032e+00],\n",
       "         [ 3.46386724e-01, -1.31247934e-01,  3.98446380e-01, ...,\n",
       "           1.21914212e+00, -2.19840434e-01, -1.81896490e+00],\n",
       "         ...,\n",
       "         [ 9.36360333e-01,  1.26908139e-01,  6.08315799e-01, ...,\n",
       "           1.23789301e+00, -2.36854475e-01, -3.55076904e-01],\n",
       "         [ 7.45753281e-01,  9.05753151e-02,  6.40187196e-01, ...,\n",
       "           1.23767624e+00, -2.25936112e-01, -3.67406742e-01],\n",
       "         [ 5.33967767e-01,  3.89441007e-02,  6.90098901e-01, ...,\n",
       "           1.23756785e+00, -2.19840434e-01, -3.53582379e-01]],\n",
       "\n",
       "        [[ 3.43360714e-01, -9.38224377e-01,  1.65124363e-01, ...,\n",
       "           1.22033437e+00, -2.51444933e-01, -1.19201134e+00],\n",
       "         [ 3.49411832e-01, -9.13364776e-01,  1.59110807e-01, ...,\n",
       "           1.22488661e+00, -2.34430892e-01, -1.18453871e+00],\n",
       "         [ 3.75128406e-01, -9.01890983e-01,  1.75347139e-01, ...,\n",
       "           1.22065953e+00, -2.23512529e-01, -1.18042876e+00],\n",
       "         ...,\n",
       "         [ 1.72904323e+00,  5.01713017e-01,  1.43475921e-01, ...,\n",
       "           1.24797298e+00, -2.17441332e-01, -4.38023085e-01],\n",
       "         [ 1.73811946e+00,  5.32308659e-01,  1.42273210e-01, ...,\n",
       "           1.24873168e+00, -2.25936112e-01, -4.38770348e-01],\n",
       "         [ 1.71089078e+00,  5.55256245e-01,  1.37462365e-01, ...,\n",
       "           1.24851491e+00, -2.30783278e-01, -4.41759399e-01]],\n",
       "\n",
       "        [[ 5.06738188e-01, -3.51158029e-01,  1.47879802e-02, ...,\n",
       "           1.23355755e+00, -2.23512529e-01, -5.60574198e-01],\n",
       "         [ 5.73299581e-01, -3.24386415e-01,  2.08013566e-02, ...,\n",
       "           1.23377432e+00, -2.18640883e-01, -5.54596095e-01],\n",
       "         [ 6.24733631e-01, -3.03351412e-01,  3.16255776e-02, ...,\n",
       "           1.23377432e+00, -2.29583727e-01, -5.49738886e-01],\n",
       "         ...,\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,\n",
       "           0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]]])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_test[:64]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-08-19T11:48:38.811810Z",
     "start_time": "2018-08-19T11:48:33.296990Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4min 53s, sys: 23.4 s, total: 5min 16s\n",
      "Wall time: 1min 2s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Instance-level predictions on the test set under the earlyPolicy_minProb policy (minProb=0.99).\n",
    "# NOTE(review): keep_prob=1.0 presumably disables dropout at inference -- confirm against EMI_Driver.\n",
    "predictions, predictionStep = emiDriver.getInstancePredictions(\n",
    "    x_test, y_test, earlyPolicy_minProb, minProb=0.99, keep_prob=1.0\n",
    ")\n",
    "# Turn the instance-level predictions into bag-level predictions, using k as minSubsequenceLen.\n",
    "bagPredictions = emiDriver.getBagPredictions(\n",
    "    predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy at k = 1: 0.854771\n"
     ]
    }
   ],
   "source": [
    "# Accuracy = mean of the element-wise match between bagPredictions and BAG_TEST.\n",
    "# NOTE(review): BAG_TEST is presumably the ground-truth bag labels -- confirm where it is defined.\n",
    "bag_accuracy = np.mean((bagPredictions == BAG_TEST).astype(int))\n",
    "print('Accuracy at k = %d: %f' % (k, bag_accuracy))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}