1074 lines (1073 with data), 51.8 kB
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Initializing and Restoring EMI-RNN Graphs\n",
"\n",
"The *EMI-RNN* implementation supports four forms of initialization/restoring:\n",
"1. An entirely new graph can be constructed with randomly initialized weights.\n",
"2. A saved graph can be loaded into the current `EMI_Driver`.\n",
"2. An entirely new graph can be constructed with weights initialized from a saved graph. This behavior is essentially restoration of a saved graph.\n",
"3. (*Experimental*) Initializing/Restoring using numpy matrices. \n",
"\n",
"All three methods will be illustrated in this notebook. This notebook uses the HAR dataset and builds on the [EMI LSTM example.ipynb](00_emi_lstm_example.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adapted from Microsoft's notebooks, available at https://github.com/microsoft/EdgeML authored by Dennis et al."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T11:47:35.576928Z",
"start_time": "2018-08-19T11:47:34.670184Z"
},
"scrolled": true
},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import os\n",
"import sys\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"# To include edgeml in python path\n",
"os.environ['CUDA_VISIBLE_DEVICES'] ='0'\n",
"\n",
"# MI-RNN and EMI-RNN imports\n",
"from edgeml.graph.rnn import EMI_DataPipeline\n",
"from edgeml.graph.rnn import EMI_BasicLSTM\n",
"from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
"import edgeml.utils"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us set up some network parameters for the computation graph."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T11:47:35.590781Z",
"start_time": "2018-08-19T11:47:35.578914Z"
}
},
"outputs": [],
"source": [
"# Network parameters for our LSTM + FC Layer\n",
"NUM_HIDDEN = 128\n",
"NUM_TIMESTEPS = 64\n",
"NUM_FEATS = 16\n",
"FORGET_BIAS = 1.0\n",
"NUM_OUTPUT = 5\n",
"USE_DROPOUT = True\n",
"KEEP_PROB = 0.75\n",
"\n",
"# For dataset API\n",
"PREFETCH_NUM = 5\n",
"BATCH_SIZE = 32\n",
"\n",
"# Number of epochs in *one iteration*\n",
"NUM_EPOCHS = 3\n",
"# Number of iterations in *one round*. After each iteration,\n",
"# the model is dumped to disk. At the end of the current\n",
"# round, the best model among all the dumped models in the\n",
"# current round is picked up..\n",
"NUM_ITER = 4\n",
"# A round consists of multiple training iterations and a belief\n",
"# update step using the best model from all of these iterations\n",
"NUM_ROUNDS = 10\n",
"\n",
"# A staging direcory to store models\n",
"dataset = 'DREAMER'\n",
"MODEL_PREFIX = '/home/sf/data/EdgeML/tf/examples/EMI-RNN' + dataset + '/model-lstm'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Network parameters for our LSTM + FC Layer\n",
"NUM_HIDDEN = 128\n",
"NUM_TIMESTEPS = 350\n",
"NUM_FEATS = 8\n",
"FORGET_BIAS = 1.0\n",
"NUM_OUTPUT = 5\n",
"USE_DROPOUT = True\n",
"KEEP_PROB = 0.75\n",
"\n",
"# For dataset API\n",
"PREFETCH_NUM = 5\n",
"BATCH_SIZE = 32\n",
"\n",
"# Number of epochs in *one iteration*\n",
"NUM_EPOCHS = 3\n",
"# Number of iterations in *one round*. After each iteration,\n",
"# the model is dumped to disk. At the end of the current\n",
"# round, the best model among all the dumped models in the\n",
"# current round is picked up..\n",
"NUM_ITER = 4\n",
"# A round consists of multiple training iterations and a belief\n",
"# update step using the best model from all of these iterations\n",
"NUM_ROUNDS = 10\n",
"\n",
"# A staging direcory to store models\n",
"dataset = 'WESAD'\n",
"MODEL_PREFIX = '/home/sf/data/EdgeML/tf/examples/EMI-RNN' + dataset + '/model-lstm'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loading Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T11:47:35.694831Z",
"start_time": "2018-08-19T11:47:35.592516Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_train shape is: (61735, 5, 64, 16)\n",
"y_train shape is: (61735, 5, 5)\n",
"x_test shape is: (6860, 5, 64, 16)\n",
"y_test shape is: (6860, 5, 5)\n"
]
}
],
"source": [
"path=\"/home/sf/data/DREAMER/Dominance/64_16/\"\n",
"x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n",
"x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n",
"x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n",
"\n",
"# BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
"# step of EMI/MI RNN\n",
"BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
"BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
"BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
"NUM_SUBINSTANCE = x_train.shape[1]\n",
"print(\"x_train shape is:\", x_train.shape)\n",
"print(\"y_train shape is:\", y_train.shape)\n",
"print(\"x_test shape is:\", x_val.shape)\n",
"print(\"y_test shape is:\", y_val.shape)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_test shape is: (16863, 5, 350, 8)\n",
"y_test shape is: (16863, 5, 5)\n"
]
}
],
"source": [
"path=\"/home/sf/data/WESAD/350_116/\"\n",
"#x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n",
"x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n",
"#x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n",
"\n",
"# BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
"# step of EMI/MI RNN\n",
"BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
"#BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
"#BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
"NUM_SUBINSTANCE = x_test.shape[1]\n",
"#print(\"x_train shape is:\", x_train.shape)\n",
"#print(\"y_train shape is:\", y_train.shape)\n",
"print(\"x_test shape is:\", x_test.shape)\n",
"print(\"y_test shape is:\", y_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T11:47:35.739003Z",
"start_time": "2018-08-19T11:47:35.696723Z"
}
},
"outputs": [],
"source": [
"# Define the linear secondary classifier\n",
"def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
" W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
" B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
" y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
" self.output = y_cap\n",
" self.graphCreated = True\n",
" \n",
"def addExtendedAssignOps(self, graph, W_val=None, B_val=None):\n",
" W1 = graph.get_tensor_by_name('W1:0')\n",
" B1 = graph.get_tensor_by_name('B1:0')\n",
" W1_op = tf.assign(W1, W_val)\n",
" B1_op = tf.assign(B1, B_val)\n",
" self.assignOps.extend([W1_op, B1_op])\n",
"\n",
"def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
" y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
" self.output = y_cap\n",
" self.graphCreated = True\n",
" \n",
"def feedDictFunc(self, keep_prob, **kwargs):\n",
" feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
" return feedDict\n",
" \n",
"EMI_BasicLSTM._createExtendedGraph = createExtendedGraph\n",
"EMI_BasicLSTM._restoreExtendedGraph = restoreExtendedGraph\n",
"EMI_BasicLSTM.addExtendedAssignOps = addExtendedAssignOps\n",
"\n",
"if USE_DROPOUT is True:\n",
" EMI_Driver.feedDictFunc = feedDictFunc"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T09:34:06.288012Z",
"start_time": "2018-08-19T09:34:06.285286Z"
}
},
"source": [
"## 1. Initializing a New Computation Graph\n",
"\n",
"In the most common use cases, a new EMI-RNN graph would be created and trained"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T11:49:55.326002Z",
"start_time": "2018-08-19T11:49:50.568621Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: Logging before flag parsing goes to stderr.\n",
"W0809 11:12:32.864859 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1141: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
"\n",
"W0809 11:12:32.881156 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1153: The name tf.data.Iterator is deprecated. Please use tf.compat.v1.data.Iterator instead.\n",
"\n",
"W0809 11:12:32.882301 140177766897472 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1153: DatasetV1.output_types (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.compat.v1.data.get_output_types(dataset)`.\n",
"W0809 11:12:32.883116 140177766897472 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1154: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.compat.v1.data.get_output_shapes(dataset)`.\n",
"W0809 11:12:32.887950 140177766897472 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py:348: Iterator.output_types (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.compat.v1.data.get_output_types(iterator)`.\n",
"W0809 11:12:32.888927 140177766897472 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py:349: Iterator.output_shapes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.compat.v1.data.get_output_shapes(iterator)`.\n",
"W0809 11:12:32.889680 140177766897472 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py:351: Iterator.output_classes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.compat.v1.data.get_output_classes(iterator)`.\n",
"W0809 11:12:32.892827 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1159: The name tf.add_to_collection is deprecated. Please use tf.compat.v1.add_to_collection instead.\n",
"\n",
"W0809 11:12:32.899183 140177766897472 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1396: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n",
"W0809 11:12:33.587531 140177766897472 lazy_loader.py:50] \n",
"The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
"For more information, please see:\n",
" * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
" * https://github.com/tensorflow/addons\n",
" * https://github.com/tensorflow/io (for I/O related ops)\n",
"If you depend on functionality not listed there, please file an issue.\n",
"\n",
"W0809 11:12:33.592689 140177766897472 deprecation.py:323] From /home/sf/data/EdgeML/tf/edgeml/graph/rnn.py:1404: static_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use `keras.layers.RNN(cell, unroll=True)`, which is equivalent to this API\n",
"W0809 11:12:33.634694 140177766897472 deprecation.py:506] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Call initializer instance with the dtype argument instead of passing it to the constructor\n",
"W0809 11:12:33.646812 140177766897472 deprecation.py:506] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py:738: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Call initializer instance with the dtype argument instead of passing it to the constructor\n",
"W0809 11:12:49.003790 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:135: The name tf.assign is deprecated. Please use tf.compat.v1.assign instead.\n",
"\n",
"W0809 11:12:49.038759 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:162: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
"\n",
"W0809 11:13:19.933787 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:329: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead.\n",
"\n"
]
}
],
"source": [
"tf.reset_default_graph()\n",
"\n",
"inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
"emiLSTM = EMI_BasicLSTM(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS,\n",
" forgetBias=FORGET_BIAS, useDropout=USE_DROPOUT)\n",
"emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy')\n",
"\n",
"# Construct the graph\n",
"g1 = tf.Graph() \n",
"with g1.as_default():\n",
" x_batch, y_batch = inputPipeline()\n",
" y_cap = emiLSTM(x_batch)\n",
" emiTrainer(y_cap, y_batch)\n",
" \n",
" \n",
"with g1.as_default():\n",
" emiDriver = EMI_Driver(inputPipeline, emiLSTM, emiTrainer)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lets initialize a new session with this graph and train a model. The saved model will be used later for restoring."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T11:48:32.894784Z",
"start_time": "2018-08-19T11:47:41.258027Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"W0809 11:13:19.965286 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:369: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
"\n",
"W0809 11:13:19.982229 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/trainer/emirnnTrainer.py:372: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.\n",
"\n"
]
}
],
"source": [
"emiDriver.initializeSession(g1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the output above indicates, the last restored model is `/tmp/model-lstm-1001`. That is, with `MODEL_PREFIX = '/tmp/model-lstm'` and `GLOBAL_STEP=1001`."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T11:48:33.294431Z",
"start_time": "2018-08-19T11:48:32.897376Z"
}
},
"outputs": [],
"source": [
"def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
" assert instanceOut.ndim == 2\n",
" classes = np.argmax(instanceOut, axis=1)\n",
" prob = np.max(instanceOut, axis=1)\n",
" index = np.where(prob >= minProb)[0]\n",
" if len(index) == 0:\n",
" assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
" return classes[-1], len(instanceOut) - 1\n",
" index = index[0]\n",
" return classes[index], index"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T11:48:33.294431Z",
"start_time": "2018-08-19T11:48:32.897376Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4min 49s, sys: 23.3 s, total: 5min 13s\n",
"Wall time: 1min 1s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"k = 1\n",
"predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
" minProb=0.99, keep_prob=1.0)\n",
"bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T11:48:33.294431Z",
"start_time": "2018-08-19T11:48:32.897376Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy at k = 1: 0.087825\n"
]
}
],
"source": [
"print('Accuracy at k = %d: %f' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))))"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"numpy.ndarray"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(x_test[:64])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Loading a Saved Graph into EMI-Driver\n",
"\n",
"We will reset the computation graph, load a saved graph into the current `EMI_Driver` and verify its outputs."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"MODEL_PREFIX='/home/sf/data/EdgeML/tf/examples/EMI-RNN//DREAMER/model-lstm'"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"MODEL_PREFIX='/home/sf/data/EdgeML/tf/examples/EMI-RNN//WESAD/model-lstm'"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T11:48:38.811810Z",
"start_time": "2018-08-19T11:48:33.296990Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"W0809 11:18:00.678148 140177766897472 deprecation_wrapper.py:119] From /home/sf/data/EdgeML/tf/edgeml/utils.py:335: The name tf.train.import_meta_graph is deprecated. Please use tf.compat.v1.train.import_meta_graph instead.\n",
"\n",
"W0809 11:18:12.201538 140177766897472 deprecation.py:323] From /home/sf/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use standard file APIs to check for files with this prefix.\n"
]
}
],
"source": [
"tf.reset_default_graph()\n",
"emiDriver.loadSavedGraphToNewSession(MODEL_PREFIX, 1007)\n",
"k = 2"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(16863, 5, 350, 8)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"k=1\n",
"x_test.shape"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[ 8.01725443e-01, -3.94587280e-02, 4.08067890e-01, ...,\n",
" -1.03118241e+00, -3.86102984e-02, -8.90117132e-01],\n",
" [ 7.85085095e-01, -3.75467141e-02, 4.11676023e-01, ...,\n",
" -1.02684695e+00, -4.58810468e-02, -8.87501712e-01],\n",
" [ 8.09289566e-01, -6.43183283e-02, 4.21297533e-01, ...,\n",
" -1.03204951e+00, -2.27468473e-02, -8.93479815e-01],\n",
" ...,\n",
" [ 7.51804399e-01, 2.26345970e-01, 3.34102320e-01, ...,\n",
" -1.03172435e+00, -3.73617860e-02, -2.82218773e-01],\n",
" [ 7.45753281e-01, 2.26345970e-01, 3.29892920e-01, ...,\n",
" -1.02847275e+00, -3.98098495e-02, -2.77735196e-01],\n",
" [ 7.51804399e-01, 2.41644361e-01, 3.28088943e-01, ...,\n",
" -1.03150757e+00, -7.14877903e-02, -2.72130724e-01]],\n",
"\n",
" [[ 7.51804399e-01, 8.48387035e-02, 3.70784382e-01, ...,\n",
" -1.03204951e+00, -3.24901398e-02, -8.54995776e-01],\n",
" [ 7.41215168e-01, 9.43999129e-02, 3.72588538e-01, ...,\n",
" -1.02847275e+00, -4.58810468e-02, -8.53501250e-01],\n",
" [ 7.36677055e-01, 1.00136525e-01, 3.65372272e-01, ...,\n",
" -1.02619663e+00, -3.49382032e-02, -8.52753987e-01],\n",
" ...,\n",
" [ 7.21549712e-01, 2.37819763e-01, 3.67777694e-01, ...,\n",
" -1.03107403e+00, -4.71295591e-02, 8.73027249e-02],\n",
" [ 7.17011599e-01, 2.37819763e-01, 3.57554918e-01, ...,\n",
" -1.02728049e+00, -3.98098495e-02, 8.99181451e-02],\n",
" [ 7.30625937e-01, 2.30170568e-01, 3.29291655e-01, ...,\n",
" -1.03064048e+00, -8.12066021e-02, 9.62698796e-02]],\n",
"\n",
" [[ 8.21390900e-01, 1.09697734e-01, 4.51364774e-01, ...,\n",
" -1.02890629e+00, -3.73617860e-02, -6.38289539e-01],\n",
" [ 8.25929013e-01, 1.00136525e-01, 4.60986283e-01, ...,\n",
" -1.02793081e+00, -7.38868924e-02, -6.30069647e-01],\n",
" [ 8.18365792e-01, 1.03961122e-01, 4.86242859e-01, ...,\n",
" -1.03107403e+00, -3.24901398e-02, -6.33432330e-01],\n",
" ...,\n",
" [ 7.29112932e-01, 2.30170568e-01, 3.09447190e-01, ...,\n",
" -1.03020693e+00, -4.71295591e-02, 4.11241187e-01],\n",
" [ 7.30625937e-01, 2.22521372e-01, 3.02231102e-01, ...,\n",
" -1.02522115e+00, -3.24901398e-02, 4.13856607e-01],\n",
" [ 7.30625937e-01, 2.03398953e-01, 3.01028391e-01, ...,\n",
" -1.03020693e+00, -4.22334323e-02, 4.23571024e-01]],\n",
"\n",
" [[ 7.45753281e-01, 2.26345970e-01, 3.29892920e-01, ...,\n",
" -1.02847275e+00, -3.98098495e-02, -2.77735196e-01],\n",
" [ 7.51804399e-01, 2.41644361e-01, 3.28088943e-01, ...,\n",
" -1.03150757e+00, -7.14877903e-02, -2.72130724e-01],\n",
" [ 7.69957752e-01, 2.33995165e-01, 3.31095632e-01, ...,\n",
" -1.02933984e+00, -3.24901398e-02, -2.67647147e-01],\n",
" ...,\n",
" [ 7.00371250e-01, 1.61328948e-01, 2.48109997e-01, ...,\n",
" -1.02955661e+00, -2.88425253e-02, 7.00805555e-01],\n",
" [ 6.94320133e-01, 1.46031127e-01, 2.57731507e-01, ...,\n",
" -1.02663017e+00, -3.86102984e-02, 7.03794607e-01],\n",
" [ 6.97345241e-01, 1.34557334e-01, 2.66150306e-01, ...,\n",
" -1.02944823e+00, -3.49382032e-02, 7.15377182e-01]],\n",
"\n",
" [[ 7.17011599e-01, 2.37819763e-01, 3.57554918e-01, ...,\n",
" -1.02728049e+00, -3.98098495e-02, 8.99181451e-02],\n",
" [ 7.30625937e-01, 2.30170568e-01, 3.29291655e-01, ...,\n",
" -1.03064048e+00, -8.12066021e-02, 9.62698796e-02],\n",
" [ 7.30625937e-01, 2.35907749e-01, 3.34703765e-01, ...,\n",
" -1.02749727e+00, -5.68483709e-02, 9.62698796e-02],\n",
" ...,\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],\n",
"\n",
"\n",
" [[[ 2.25366173e-01, -6.37998299e-01, 2.01606543e+00, ...,\n",
" 1.00692097e+00, 2.54545298e-01, -5.98684606e-01],\n",
" [ 2.35955404e-01, -6.37998299e-01, 1.99441699e+00, ...,\n",
" 1.00713774e+00, 2.33761239e-01, -5.96442817e-01],\n",
" [ 2.34442399e-01, -6.24612492e-01, 1.98239024e+00, ...,\n",
" 1.00724613e+00, 2.14152251e-01, -6.00179131e-01],\n",
" ...,\n",
" [ 5.47583007e-01, 1.06774322e+00, 1.37142302e+00, ...,\n",
" 1.04886662e+00, 2.38657366e-01, -1.62542381e+00],\n",
" [ 5.35480772e-01, 1.06774322e+00, 1.37863929e+00, ...,\n",
" 1.04973371e+00, 2.43529012e-01, -1.62542381e+00],\n",
" [ 5.40018885e-01, 1.09834000e+00, 1.38465266e+00, ...,\n",
" 1.04713243e+00, 2.41105429e-01, -1.62579744e+00]],\n",
"\n",
" [[ 3.53949945e-01, -1.29335350e-01, 1.86693176e+00, ...,\n",
" 1.01179837e+00, 2.58241873e-01, -9.68953366e-01],\n",
" [ 3.55462950e-01, -1.56106964e-01, 1.86392489e+00, ...,\n",
" 1.01223192e+00, 2.54545298e-01, -9.71568786e-01],\n",
" [ 3.49411832e-01, -1.52282366e-01, 1.85911441e+00, ...,\n",
" 1.01158160e+00, 1.94592224e-01, -9.73063312e-01],\n",
" ...,\n",
" [ 5.40018885e-01, 1.45784535e+00, 1.32091023e+00, ...,\n",
" 1.06664204e+00, 2.37408854e-01, -1.66876506e+00],\n",
" [ 5.43044894e-01, 1.46549455e+00, 1.32692360e+00, ...,\n",
" 1.06664204e+00, 2.41105429e-01, -1.67025959e+00],\n",
" [ 5.18840424e-01, 1.44445954e+00, 1.33774782e+00, ...,\n",
" 1.06707559e+00, 2.09280605e-01, -1.66951232e+00]],\n",
"\n",
" [[ 2.93439669e-01, 3.44906790e-01, 1.50492198e+00, ...,\n",
" 1.02458800e+00, 2.50897683e-01, -1.34183755e+00],\n",
" [ 2.97978684e-01, 3.39169608e-01, 1.49409776e+00, ...,\n",
" 1.02491316e+00, 2.54545298e-01, -1.35304649e+00],\n",
" [ 3.01003792e-01, 3.60204611e-01, 1.49048945e+00, ...,\n",
" 1.02480478e+00, 2.53345746e-01, -1.34968381e+00],\n",
" ...,\n",
" [ 1.01320514e-01, 1.67775602e+00, 1.18620873e+00, ...,\n",
" 1.09298001e+00, 1.93392673e-01, -1.45131156e+00],\n",
" [ 9.52693960e-02, 1.64142376e+00, 1.18620873e+00, ...,\n",
" 1.09395549e+00, 2.43529012e-01, -1.44794888e+00],\n",
" [ 8.77052734e-02, 1.62994940e+00, 1.18380331e+00, ...,\n",
" 1.09438903e+00, 2.23944505e-01, -1.44832251e+00]],\n",
"\n",
" [[ 5.35480772e-01, 1.06774322e+00, 1.37863929e+00, ...,\n",
" 1.04973371e+00, 2.43529012e-01, -1.62542381e+00],\n",
" [ 5.40018885e-01, 1.09834000e+00, 1.38465266e+00, ...,\n",
" 1.04713243e+00, 2.41105429e-01, -1.62579744e+00],\n",
" [ 5.43044894e-01, 1.12128645e+00, 1.39186893e+00, ...,\n",
" 1.04962533e+00, 2.38657366e-01, -1.62841286e+00],\n",
" ...,\n",
" [ 4.82534619e-01, 2.75818692e+00, 1.23491790e+00, ...,\n",
" 1.12538763e+00, 2.37408854e-01, -1.04554781e+00],\n",
" [ 4.55305040e-01, 2.76009951e+00, 1.24393779e+00, ...,\n",
" 1.12625472e+00, 2.43529012e-01, -1.04031697e+00],\n",
" [ 4.22024343e-01, 2.74288853e+00, 1.25115406e+00, ...,\n",
" 1.12603795e+00, 2.43529012e-01, -1.03583339e+00]],\n",
"\n",
" [[ 5.43044894e-01, 1.46549455e+00, 1.32692360e+00, ...,\n",
" 1.06664204e+00, 2.41105429e-01, -1.67025959e+00],\n",
" [ 5.18840424e-01, 1.44445954e+00, 1.33774782e+00, ...,\n",
" 1.06707559e+00, 2.09280605e-01, -1.66951232e+00],\n",
" [ 4.99174967e-01, 1.45402132e+00, 1.33173445e+00, ...,\n",
" 1.06696720e+00, 1.30991536e-01, -1.66540238e+00],\n",
" ...,\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],\n",
"\n",
"\n",
" [[[ 2.97978684e-01, 3.44906790e-01, -3.25573536e-01, ...,\n",
" 1.89200671e+00, 3.88185080e-01, -2.41866577e-01],\n",
" [ 3.01003792e-01, 3.52555985e-01, -3.30985646e-01, ...,\n",
" 1.89103123e+00, 3.95529271e-01, -2.41492945e-01],\n",
" [ 3.07054910e-01, 3.35345580e-01, -3.38803179e-01, ...,\n",
" 1.89449960e+00, 4.07843030e-01, -2.44481997e-01],\n",
" ...,\n",
" [ 3.61514068e-01, 3.12397994e-01, -2.44993144e-01, ...,\n",
" 1.89850991e+00, 4.17659764e-01, -4.59693709e-01],\n",
" [ 3.78154416e-01, 3.41082192e-01, -2.14324638e-01, ...,\n",
" 1.89731766e+00, 4.15187220e-01, -4.96309590e-01],\n",
" [ 3.79667421e-01, 3.60204611e-01, -2.04703128e-01, ...,\n",
" 1.89731766e+00, 4.15187220e-01, -4.59693709e-01]],\n",
"\n",
" [[ 2.94952674e-01, 3.67853806e-01, -3.43613844e-01, ...,\n",
" 1.89092284e+00, 4.10291093e-01, -3.93934574e-01],\n",
" [ 2.97978684e-01, 3.60204611e-01, -3.43012578e-01, ...,\n",
" 1.89124800e+00, 3.61207421e-01, -3.93187312e-01],\n",
" [ 2.97978684e-01, 3.56380013e-01, -3.42411133e-01, ...,\n",
" 1.89113962e+00, 4.17659764e-01, -4.00659940e-01],\n",
" ...,\n",
" [ 2.67723095e-01, 3.79327599e-01, -3.41208422e-01, ...,\n",
" 1.89341574e+00, 4.17659764e-01, -4.26066879e-01],\n",
" [ 2.66210091e-01, 3.67853806e-01, -3.37600468e-01, ...,\n",
" 1.89341574e+00, 4.07843030e-01, -4.23451459e-01],\n",
" [ 2.58646870e-01, 3.67853806e-01, -3.32789623e-01, ...,\n",
" 1.89384928e+00, 4.12739156e-01, -4.23077827e-01]],\n",
"\n",
" [[ 2.82850439e-01, 3.37257594e-01, -3.54438065e-01, ...,\n",
" 1.88918866e+00, 4.21331859e-01, -4.62682760e-01],\n",
" [ 2.97978684e-01, 3.23871787e-01, -3.36397757e-01, ...,\n",
" 1.88951382e+00, 4.20107827e-01, -4.58946446e-01],\n",
" [ 2.97978684e-01, 3.41082192e-01, -3.40607156e-01, ...,\n",
" 1.89027252e+00, 3.42797984e-01, -4.53715605e-01],\n",
" ...,\n",
" [ 3.13106027e-01, 3.75503002e-01, -2.82276652e-01, ...,\n",
" 1.89428283e+00, 4.22580371e-01, -3.69274899e-01],\n",
" [ 3.05541905e-01, 3.71678404e-01, -3.02722382e-01, ...,\n",
" 1.89471638e+00, 4.26252467e-01, -3.61055007e-01],\n",
" [ 3.10080018e-01, 3.52555985e-01, -3.26174802e-01, ...,\n",
" 1.89558347e+00, 4.20107827e-01, -4.29055930e-01]],\n",
"\n",
" [[ 3.78154416e-01, 3.41082192e-01, -2.14324638e-01, ...,\n",
" 1.89731766e+00, 4.15187220e-01, -4.96309590e-01],\n",
" [ 3.79667421e-01, 3.60204611e-01, -2.04703128e-01, ...,\n",
" 1.89731766e+00, 4.15187220e-01, -4.59693709e-01],\n",
" [ 3.73616303e-01, 3.75503002e-01, -1.84257397e-01, ...,\n",
" 1.89710088e+00, 4.21331859e-01, -4.56704657e-01],\n",
" ...,\n",
" [ 3.10080018e-01, 3.48731388e-01, -3.35796311e-01, ...,\n",
" 1.89178994e+00, 4.15187220e-01, -3.42373435e-01],\n",
" [ 3.01003792e-01, 3.35345580e-01, -3.37600468e-01, ...,\n",
" 1.89200671e+00, 4.11515125e-01, -3.44241592e-01],\n",
" [ 3.01003792e-01, 3.46818804e-01, -3.30384201e-01, ...,\n",
" 1.89254864e+00, 3.62431453e-01, -3.44615224e-01]],\n",
"\n",
" [[ 2.66210091e-01, 3.67853806e-01, -3.37600468e-01, ...,\n",
" 1.89341574e+00, 4.07843030e-01, -4.23451459e-01],\n",
" [ 2.58646870e-01, 3.67853806e-01, -3.32789623e-01, ...,\n",
" 1.89384928e+00, 4.12739156e-01, -4.23077827e-01],\n",
" [ 2.69236100e-01, 3.46818804e-01, -3.29181490e-01, ...,\n",
" 1.89309058e+00, 4.12739156e-01, -5.36661786e-01],\n",
" ...,\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],\n",
"\n",
"\n",
" ...,\n",
"\n",
"\n",
" [[[-1.34335927e+00, 2.22521372e-01, -1.34004308e+00, ...,\n",
" 1.46420443e+00, 1.65215463e-01, 2.04622998e-01],\n",
" [-1.34941039e+00, 2.07223551e-01, -1.32741488e+00, ...,\n",
" 1.46463798e+00, 1.75032197e-01, 2.05370261e-01],\n",
" [-1.34487228e+00, 1.91925160e-01, -1.35026604e+00, ...,\n",
" 1.46420443e+00, 1.62816361e-01, 2.04622998e-01],\n",
" ...,\n",
" [-1.34033417e+00, 1.84276535e-01, -1.35928610e+00, ...,\n",
" 1.45575027e+00, 1.75032197e-01, 4.22823761e-01],\n",
" [-1.34033417e+00, 2.07223551e-01, -1.34725935e+00, ...,\n",
" 1.45553350e+00, 1.75032197e-01, 4.22823761e-01],\n",
" [-1.33882116e+00, 2.16784761e-01, -1.35267146e+00, ...,\n",
" 1.45575027e+00, 1.72584134e-01, 4.30670022e-01]],\n",
"\n",
" [[-1.33277004e+00, 2.20609358e-01, -1.34425248e+00, ...,\n",
" 1.46095283e+00, 1.26119890e-01, 2.51326929e-01],\n",
" [-1.33277004e+00, 1.88100563e-01, -1.36349568e+00, ...,\n",
" 1.46095283e+00, 1.46928429e-01, 2.51326929e-01],\n",
" [-1.33579605e+00, 2.16784761e-01, -1.35447544e+00, ...,\n",
" 1.46073606e+00, 1.59144266e-01, 2.51700560e-01],\n",
" ...,\n",
" [-1.35999962e+00, 1.95749758e-01, -1.35146875e+00, ...,\n",
" 1.45444963e+00, 1.62816361e-01, 5.75265390e-01],\n",
" [-1.36453773e+00, 1.91925160e-01, -1.35146875e+00, ...,\n",
" 1.45509995e+00, 1.45679917e-01, 5.81243493e-01],\n",
" [-1.35092340e+00, 1.80451937e-01, -1.34786061e+00, ...,\n",
" 1.45531672e+00, 1.65215463e-01, 5.87221597e-01]],\n",
"\n",
" [[-1.35697451e+00, 2.22521372e-01, -1.33523242e+00, ...,\n",
" 1.45835155e+00, 1.79903844e-01, 3.14470643e-01],\n",
" [-1.34789739e+00, 2.26345970e-01, -1.35447544e+00, ...,\n",
" 1.45835155e+00, 1.72584134e-01, 3.21196009e-01],\n",
" [-1.34033417e+00, 2.01486940e-01, -1.35928610e+00, ...,\n",
" 1.45845994e+00, 1.66463975e-01, 3.24185060e-01],\n",
" ...,\n",
" [-1.29646424e+00, 1.86188549e-01, -1.32260404e+00, ...,\n",
" 1.46073606e+00, 1.70136070e-01, 7.59465692e-01],\n",
" [-1.29948935e+00, 1.65153546e-01, -1.33823910e+00, ...,\n",
" 1.46116960e+00, 1.55472170e-01, 7.63202007e-01],\n",
" [-1.31612969e+00, 1.91925160e-01, -1.34425248e+00, ...,\n",
" 1.46095283e+00, 1.55472170e-01, 7.61707481e-01]],\n",
"\n",
" [[-1.34033417e+00, 2.07223551e-01, -1.34725935e+00, ...,\n",
" 1.45553350e+00, 1.75032197e-01, 4.22823761e-01],\n",
" [-1.33882116e+00, 2.16784761e-01, -1.35267146e+00, ...,\n",
" 1.45575027e+00, 1.72584134e-01, 4.30670022e-01],\n",
" [-1.35394850e+00, 2.03398953e-01, -1.34605664e+00, ...,\n",
" 1.45596704e+00, 1.45679917e-01, 4.21702867e-01],\n",
" ...,\n",
" [-1.39176731e+00, 2.01486940e-01, -1.34725935e+00, ...,\n",
" 1.45444963e+00, 1.75032197e-01, 9.48896834e-01],\n",
" [-1.39328032e+00, 2.01486940e-01, -1.34485393e+00, ...,\n",
" 1.45455802e+00, 1.72584134e-01, 9.60853040e-01],\n",
" [-1.40084444e+00, 2.03398953e-01, -1.34064453e+00, ...,\n",
" 1.45455802e+00, 1.67688007e-01, 9.58984883e-01]],\n",
"\n",
" [[-1.36453773e+00, 1.91925160e-01, -1.35146875e+00, ...,\n",
" 1.45509995e+00, 1.45679917e-01, 5.81243493e-01],\n",
" [-1.35092340e+00, 1.80451937e-01, -1.34786061e+00, ...,\n",
" 1.45531672e+00, 1.65215463e-01, 5.87221597e-01],\n",
" [-1.35697451e+00, 1.72802742e-01, -1.35146875e+00, ...,\n",
" 1.45509995e+00, 1.26119890e-01, 5.89463385e-01],\n",
" ...,\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],\n",
"\n",
"\n",
" [[[ 7.51804399e-01, 8.34446750e-01, 9.02373921e-01, ...,\n",
" -5.44851320e-01, 3.19516901e-01, 8.35664105e-02],\n",
" [ 7.44241178e-01, 8.38271918e-01, 9.11995252e-01, ...,\n",
" -5.45393254e-01, 3.04779559e-01, 9.10390394e-02],\n",
" [ 7.51804399e-01, 8.53569169e-01, 9.11995252e-01, ...,\n",
" -5.43767453e-01, 3.07252103e-01, 6.67529955e-02],\n",
" ...,\n",
" [ 6.80704892e-01, 1.02567322e+00, 9.45670805e-01, ...,\n",
" -5.44634547e-01, 3.15820325e-01, -2.07866116e-01],\n",
" [ 6.80704892e-01, 1.02949839e+00, 9.45670805e-01, ...,\n",
" -5.43767453e-01, 3.29284674e-01, -1.49579610e-01],\n",
" [ 6.76166779e-01, 9.96989019e-01, 9.27630317e-01, ...,\n",
" -5.75416368e-01, 3.29284674e-01, -1.47337822e-01]],\n",
"\n",
" [[ 7.42728173e-01, 9.14761593e-01, 9.49278759e-01, ...,\n",
" -5.37697798e-01, 2.47201107e-01, -4.27210174e-02],\n",
" [ 7.45753281e-01, 9.10937565e-01, 9.34245139e-01, ...,\n",
" -5.46043574e-01, 3.09675686e-01, 1.29500677e-02],\n",
" [ 7.33651045e-01, 9.07113537e-01, 9.42062493e-01, ...,\n",
" -5.44851320e-01, 3.09675686e-01, 1.14555419e-02],\n",
" ...,\n",
" [ 7.01883354e-01, 1.03714758e+00, 9.80548532e-01, ...,\n",
" -5.41708106e-01, 3.32981250e-01, -1.70876603e-01],\n",
" [ 7.06421466e-01, 1.02949839e+00, 9.72129733e-01, ...,\n",
" -5.40407466e-01, 3.21964964e-01, -1.70129340e-01],\n",
" [ 7.12472584e-01, 1.04097161e+00, 9.73933710e-01, ...,\n",
" -5.33145557e-01, 3.02355976e-01, -1.70129340e-01]],\n",
"\n",
" [[ 7.36677055e-01, 9.95077575e-01, 9.51684182e-01, ...,\n",
" -5.42900360e-01, 3.24413028e-01, -7.33587958e-02],\n",
" [ 7.27600829e-01, 9.68305961e-01, 9.58299003e-01, ...,\n",
" -5.44092613e-01, 3.98001815e-01, -7.44796902e-02],\n",
" [ 7.10960481e-01, 9.72129989e-01, 9.61305512e-01, ...,\n",
" -5.41816493e-01, 3.31757218e-01, -7.41060587e-02],\n",
" ...,\n",
" [ 7.36677055e-01, 1.00272677e+00, 9.61906957e-01, ...,\n",
" -5.44742933e-01, 3.26861091e-01, -1.64524868e-01],\n",
" [ 7.35164050e-01, 9.98901603e-01, 9.54690691e-01, ...,\n",
" -5.43333906e-01, 2.98659401e-01, -1.61535817e-01],\n",
" [ 7.36677055e-01, 9.98901603e-01, 9.60704426e-01, ...,\n",
" -5.42683586e-01, 3.23164516e-01, -1.60788554e-01]],\n",
"\n",
" [[ 6.80704892e-01, 1.02949839e+00, 9.45670805e-01, ...,\n",
" -5.43767453e-01, 3.29284674e-01, -1.49579610e-01],\n",
" [ 6.76166779e-01, 9.96989019e-01, 9.27630317e-01, ...,\n",
" -5.75416368e-01, 3.29284674e-01, -1.47337822e-01],\n",
" [ 6.82217897e-01, 1.01037597e+00, 9.37251648e-01, ...,\n",
" -5.43442293e-01, 2.85170571e-01, -1.49579610e-01],\n",
" ...,\n",
" [ 7.03396358e-01, 1.00655080e+00, 9.77542023e-01, ...,\n",
" -5.40299079e-01, 3.25637060e-01, -2.00019855e-01],\n",
" [ 7.26087824e-01, 9.96989019e-01, 9.84156845e-01, ...,\n",
" -5.45826800e-01, 3.17068838e-01, -1.97778067e-01],\n",
" [ 7.21549712e-01, 9.98901603e-01, 9.85359377e-01, ...,\n",
" -5.46043574e-01, 3.32981250e-01, -1.99272592e-01]],\n",
"\n",
" [[ 7.06421466e-01, 1.02949839e+00, 9.72129733e-01, ...,\n",
" -5.40407466e-01, 3.21964964e-01, -1.70129340e-01],\n",
" [ 7.12472584e-01, 1.04097161e+00, 9.73933710e-01, ...,\n",
" -5.33145557e-01, 3.02355976e-01, -1.70129340e-01],\n",
" [ 7.12472584e-01, 1.04097161e+00, 9.56495026e-01, ...,\n",
" -5.44959707e-01, 3.13372262e-01, -1.65272131e-01],\n",
" ...,\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],\n",
"\n",
"\n",
" [[[ 1.04376564e+00, 1.32398728e+00, -4.93348782e-01, ...,\n",
" 1.24201171e+00, -2.46573287e-01, -1.40572852e+00],\n",
" [ 1.14360773e+00, 1.28765389e+00, -4.77112450e-01, ...,\n",
" 1.24331235e+00, -2.35654924e-01, -1.41544294e+00],\n",
" [ 1.24496193e+00, 1.23984727e+00, -4.57869430e-01, ...,\n",
" 1.24331235e+00, -2.32007309e-01, -1.41768473e+00],\n",
" ...,\n",
" [ 3.43360714e-01, -9.86030993e-01, 1.62117675e-01, ...,\n",
" 1.21870857e+00, -2.28384176e-01, -1.19836307e+00],\n",
" [ 3.43360714e-01, -9.38224377e-01, 1.65124363e-01, ...,\n",
" 1.22033437e+00, -2.51444933e-01, -1.19201134e+00],\n",
" [ 3.49411832e-01, -9.13364776e-01, 1.59110807e-01, ...,\n",
" 1.22488661e+00, -2.34430892e-01, -1.18453871e+00]],\n",
"\n",
" [[-6.61107696e-01, -6.99190723e-01, 3.32298343e-01, ...,\n",
" 1.23073949e+00, -2.17441332e-01, -1.86716336e+00],\n",
" [-7.06490629e-01, -7.46997340e-01, 3.34703765e-01, ...,\n",
" 1.23073949e+00, -2.19840434e-01, -1.87127330e+00],\n",
" [-7.18591962e-01, -7.98628554e-01, 3.22676833e-01, ...,\n",
" 1.23052272e+00, -2.22288498e-01, -1.87127330e+00],\n",
" ...,\n",
" [ 4.34126579e-01, -3.70281017e-01, 3.55625774e-04, ...,\n",
" 1.23355755e+00, -2.35654924e-01, -5.60574198e-01],\n",
" [ 5.06738188e-01, -3.51158029e-01, 1.47879802e-02, ...,\n",
" 1.23355755e+00, -2.23512529e-01, -5.60574198e-01],\n",
" [ 5.73299581e-01, -3.24386415e-01, 2.08013566e-02, ...,\n",
" 1.23377432e+00, -2.18640883e-01, -5.54596095e-01]],\n",
"\n",
" [[ 4.52279030e-01, -7.38795377e-02, 3.88824870e-01, ...,\n",
" 1.21968405e+00, -2.23512529e-01, -1.82793206e+00],\n",
" [ 4.12947216e-01, -9.87391379e-02, 3.99649091e-01, ...,\n",
" 1.21827502e+00, -2.27135663e-01, -1.82158032e+00],\n",
" [ 3.46386724e-01, -1.31247934e-01, 3.98446380e-01, ...,\n",
" 1.21914212e+00, -2.19840434e-01, -1.81896490e+00],\n",
" ...,\n",
" [ 9.36360333e-01, 1.26908139e-01, 6.08315799e-01, ...,\n",
" 1.23789301e+00, -2.36854475e-01, -3.55076904e-01],\n",
" [ 7.45753281e-01, 9.05753151e-02, 6.40187196e-01, ...,\n",
" 1.23767624e+00, -2.25936112e-01, -3.67406742e-01],\n",
" [ 5.33967767e-01, 3.89441007e-02, 6.90098901e-01, ...,\n",
" 1.23756785e+00, -2.19840434e-01, -3.53582379e-01]],\n",
"\n",
" [[ 3.43360714e-01, -9.38224377e-01, 1.65124363e-01, ...,\n",
" 1.22033437e+00, -2.51444933e-01, -1.19201134e+00],\n",
" [ 3.49411832e-01, -9.13364776e-01, 1.59110807e-01, ...,\n",
" 1.22488661e+00, -2.34430892e-01, -1.18453871e+00],\n",
" [ 3.75128406e-01, -9.01890983e-01, 1.75347139e-01, ...,\n",
" 1.22065953e+00, -2.23512529e-01, -1.18042876e+00],\n",
" ...,\n",
" [ 1.72904323e+00, 5.01713017e-01, 1.43475921e-01, ...,\n",
" 1.24797298e+00, -2.17441332e-01, -4.38023085e-01],\n",
" [ 1.73811946e+00, 5.32308659e-01, 1.42273210e-01, ...,\n",
" 1.24873168e+00, -2.25936112e-01, -4.38770348e-01],\n",
" [ 1.71089078e+00, 5.55256245e-01, 1.37462365e-01, ...,\n",
" 1.24851491e+00, -2.30783278e-01, -4.41759399e-01]],\n",
"\n",
" [[ 5.06738188e-01, -3.51158029e-01, 1.47879802e-02, ...,\n",
" 1.23355755e+00, -2.23512529e-01, -5.60574198e-01],\n",
" [ 5.73299581e-01, -3.24386415e-01, 2.08013566e-02, ...,\n",
" 1.23377432e+00, -2.18640883e-01, -5.54596095e-01],\n",
" [ 6.24733631e-01, -3.03351412e-01, 3.16255776e-02, ...,\n",
" 1.23377432e+00, -2.29583727e-01, -5.49738886e-01],\n",
" ...,\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]])"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_test[:64]"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"ExecuteTime": {
"end_time": "2018-08-19T11:48:38.811810Z",
"start_time": "2018-08-19T11:48:33.296990Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4min 53s, sys: 23.4 s, total: 5min 16s\n",
"Wall time: 1min 2s\n"
]
}
],
"source": [
"%%time\n",
"# Stage 1: per-instance predictions over x_test / y_test, passing earlyPolicy_minProb\n",
"# as the policy argument with minProb=0.99.\n",
"# NOTE(review): keep_prob=1.0 presumably disables dropout for inference -- confirm against EMI_Driver.\n",
"predictions, predictionStep = emiDriver.getInstancePredictions(\n",
"    x_test,\n",
"    y_test,\n",
"    earlyPolicy_minProb,\n",
"    minProb=0.99,\n",
"    keep_prob=1.0\n",
")\n",
"\n",
"# Stage 2: combine the instance-level predictions into bag-level predictions.\n",
"bagPredictions = emiDriver.getBagPredictions(\n",
"    predictions,\n",
"    minSubsequenceLen=k,\n",
"    numClass=NUM_OUTPUT\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy at k = 1: 0.854771\n"
]
}
],
"source": [
"# Accuracy for the current k: the mean of the element-wise match between\n",
"# bagPredictions and BAG_TEST (1 where equal, 0 otherwise).\n",
"bag_hits = (bagPredictions == BAG_TEST).astype(int)\n",
"bag_accuracy = np.mean(bag_hits)\n",
"print('Accuracy at k = %d: %f' % (k, bag_accuracy))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}