--- a +++ b/SWELL-KW/SWELL-KW_LSTM.ipynb @@ -0,0 +1,1493 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SWELL-KW LSTM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Adapted from Microsoft's notebooks, available at https://github.com/microsoft/EdgeML authored by Dennis et al." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-15T18:30:17.522073Z", + "start_time": "2019-07-15T18:30:17.217355Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from tabulate import tabulate\n", + "import os\n", + "import datetime as datetime\n", + "import pickle as pkl\n", + "import pathlib" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## DataFrames from CSVs" + ] + }, + { + "cell_type": "code", + "execution_count": 224, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-15T18:30:28.212447Z", + "start_time": "2019-07-15T18:30:17.889941Z" + } + }, + "outputs": [], + "source": [ + "df = pd.read_csv('/home/sf/data/SWELL-KW/EDA_New.csv', index_col=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preprocessing " + ] + }, + { + "cell_type": "code", + "execution_count": 226, + "metadata": {}, + "outputs": [], + "source": [ + "df['condition'].replace('time pressure', 'interruption', inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 227, + "metadata": {}, + "outputs": [], + "source": [ + "df['condition'].replace('interruption', 1, inplace=True)\n", + "df['condition'].replace('no stress', 0, inplace=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Split Ground Truth " + ] + }, + { + "cell_type": "code", + "execution_count": 188, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.027338218248412812" + ] + }, + "execution_count": 188, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df[df['condition']==1]['RANGE'].mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 229, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-15T18:30:37.760226Z", + "start_time": "2019-07-15T18:30:37.597908Z" + } + }, + "outputs": [], + "source": [ + "filtered_train = df.drop(['condition'], axis=1)\n", + "filtered_target = df['condition']" + ] + }, + { + "cell_type": "code", + "execution_count": 231, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-15T18:32:27.581110Z", + "start_time": "2019-07-15T18:32:27.576162Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(104040,)\n", + "(104040, 22)\n" + ] + } + ], + "source": [ + "print(filtered_target.shape)\n", + "print(filtered_train.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 232, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 52695\n", + "1 51345\n", + "Name: condition, dtype: int64" + ] + }, + "execution_count": 232, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.value_counts(filtered_target)" + ] + }, + { + "cell_type": "code", + "execution_count": 233, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:24:41.449331Z", + "start_time": "2019-06-19T05:24:41.445741Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "y = filtered_target.values.reshape(-1, 20)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert to 3D - (Bags, Timesteps, Features)" + ] + }, + { + "cell_type": "code", + "execution_count": 234, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "22" + ] + }, + "execution_count": 234, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(filtered_train.columns)" + ] + }, + { + "cell_type": "code", + "execution_count": 235, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:24:43.963421Z", + "start_time": "2019-06-19T05:24:43.944207Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(104040, 22)\n", + "(5202, 20, 22)\n" + ] + } + ], + "source": [ + "x = filtered_train.values\n", + "print(x.shape)\n", + "x = x.reshape(int(len(x) / 20), 20, 22)\n", + "print(x.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Filter Overlapping Bags" + ] + }, + { + "cell_type": "code", + "execution_count": 236, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:25:19.299273Z", + "start_time": "2019-06-19T05:25:18.952971Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[120, 198, 289, 408, 410, 415, 544, 667, 776, 894, 908, 911, 913, 1019, 1103, 1117, 1147, 1150, 1191, 1305, 1440, 1549, 1643, 1660, 1685, 1805, 1937, 2042, 2175, 2179, 2183, 2184, 2384, 2490, 2566, 2571, 2594, 2597, 2637, 2638, 2761, 2810, 2918, 2990, 3004, 3008, 3015, 3056, 3164, 3236, 3244, 3274, 3277, 3318, 3433, 3546, 3556, 3560, 3562, 3563, 3564, 3676, 3786, 3796, 3800, 3839, 3840, 3841, 3956, 4051, 4081, 4084, 4125, 4126, 4232, 4310, 4411, 4527, 4544, 4546, 4549, 4554, 4559, 4564, 4687, 4789, 4895, 4968, 4985, 4989, 5008, 5131]\n" + ] + } + ], + "source": [ + "# filtering bags that overlap with another class\n", + "bags_to_remove = []\n", + "for i in range(len(y)):\n", + " if len(set(y[i])) > 1:\n", + " bags_to_remove.append(i)\n", + "print(bags_to_remove)" + ] + }, + { + "cell_type": "code", + "execution_count": 237, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:25:26.446153Z", + "start_time": "2019-06-19T05:25:26.256245Z" + } + }, + "outputs": [], + "source": [ + "x = np.delete(x, bags_to_remove, axis=0)\n", + "y = np.delete(y, bags_to_remove, axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 238, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:25:27.260726Z", + "start_time": "2019-06-19T05:25:27.254474Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(5110, 20, 22)" + ] + }, + "execution_count": 238, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x.shape " + ] + }, + { + "cell_type": "code", + "execution_count": 239, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:25:28.101475Z", + "start_time": "2019-06-19T05:25:28.096009Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(5110, 20)" + ] + }, + "execution_count": 239, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Categorical Representation " + ] + }, + { + "cell_type": "code", + "execution_count": 240, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:25:50.094089Z", + "start_time": "2019-06-19T05:25:49.746284Z" + } + }, + "outputs": [], + "source": [ + "one_hot_list = []\n", + "for i in range(len(y)):\n", + " one_hot_list.append(set(y[i]).pop())" + ] + }, + { + "cell_type": "code", + "execution_count": 241, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:25:52.203633Z", + "start_time": "2019-06-19T05:25:52.198467Z" + } + }, + "outputs": [], + "source": [ + "categorical_y_ver = one_hot_list\n", + "categorical_y_ver = np.array(categorical_y_ver)" + ] + }, + { + "cell_type": "code", + "execution_count": 242, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:25:53.132006Z", + "start_time": "2019-06-19T05:25:53.126314Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(5110,)" + ] + }, + "execution_count": 242, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "categorical_y_ver.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 243, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:25:54.495163Z", + "start_time": "2019-06-19T05:25:54.489349Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "20" + ] + }, + "execution_count": 243, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x.shape[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 244, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:26:08.021392Z", + "start_time": "2019-06-19T05:26:08.017038Z" + } + }, + "outputs": [], + "source": [ + "def one_hot(y, numOutput):\n", + " y = np.reshape(y, [-1])\n", + " ret = np.zeros([y.shape[0], numOutput])\n", + " for i, label in enumerate(y):\n", + " ret[i, label] = 1\n", + " return ret" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Extract 3D Normalized Data with Validation Set" + ] + }, + { + "cell_type": "code", + "execution_count": 245, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:26:08.931435Z", + "start_time": "2019-06-19T05:26:08.927397Z" + } + }, + "outputs": [], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "import pathlib" + ] + }, + { + "cell_type": "code", + "execution_count": 246, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:26:10.295822Z", + "start_time": "2019-06-19T05:26:09.832723Z" + } + }, + "outputs": [], + "source": [ + "x_train_val_combined, x_test, y_train_val_combined, y_test = train_test_split(x, categorical_y_ver, test_size=0.20, random_state=42)" + ] + }, + { + "cell_type": "code", + "execution_count": 247, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:26:11.260084Z", + "start_time": "2019-06-19T05:26:11.253714Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1, 1, 0, ..., 1, 0, 1])" + ] + }, + "execution_count": 247, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_test" + ] + }, + { + "cell_type": "code", + "execution_count": 248, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-18T17:24:41.832337Z", + "start_time": "2019-06-18T17:24:41.141678Z" + } + }, + "outputs": [], + "source": [ + "extractedDir = '/home/sf/data/SWELL-KW/'\n", + "# def generateData(extractedDir):\n", + "# x_train_val_combined, x_test, y_train_val_combined, y_test = readData(extractedDir)\n", + "timesteps = x_train_val_combined.shape[-2] #128/256\n", + "feats = x_train_val_combined.shape[-1] #16\n", + "\n", + "trainSize = int(x_train_val_combined.shape[0]*0.9) #6566\n", + "x_train, x_val = x_train_val_combined[:trainSize], x_train_val_combined[trainSize:] \n", + "y_train, y_val = y_train_val_combined[:trainSize], y_train_val_combined[trainSize:]\n", + "\n", + "# normalization\n", + "x_train = np.reshape(x_train, [-1, feats])\n", + "mean = np.mean(x_train, axis=0)\n", + "std = np.std(x_train, axis=0)\n", + "\n", + "# normalize train\n", + "x_train = x_train - mean\n", + "x_train = x_train / std\n", + "x_train = np.reshape(x_train, [-1, timesteps, feats])\n", + "\n", + "# normalize val\n", + "x_val = np.reshape(x_val, [-1, feats])\n", + "x_val = x_val - mean\n", + "x_val = x_val / std\n", + "x_val = np.reshape(x_val, [-1, timesteps, feats])\n", + "\n", + "# normalize test\n", + "x_test = np.reshape(x_test, [-1, feats])\n", + "x_test = x_test - mean\n", + "x_test = x_test / std\n", + "x_test = np.reshape(x_test, [-1, timesteps, feats])\n", + "\n", + "# shuffle test, as this was remaining\n", + "idx = np.arange(len(x_test))\n", + "np.random.shuffle(idx)\n", + "x_test = x_test[idx]\n", + "y_test = y_test[idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 249, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-18T17:25:50.674962Z", + "start_time": "2019-06-18T17:25:50.481068Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/sf/data/SWELL-KW//\n" + ] + } + ], + "source": [ + "# one-hot encoding of labels\n", + "numOutput = 2\n", + "y_train = one_hot(y_train, numOutput)\n", + "y_val = one_hot(y_val, numOutput)\n", + "y_test = one_hot(y_test, numOutput)\n", + "extractedDir += '/'\n", + "\n", + "pathlib.Path(extractedDir + 'RAW').mkdir(parents=True, exist_ok = True)\n", + "\n", + "np.save(extractedDir + \"RAW/x_train\", x_train)\n", + "np.save(extractedDir + \"RAW/y_train\", y_train)\n", + "np.save(extractedDir + \"RAW/x_test\", x_test)\n", + "np.save(extractedDir + \"RAW/y_test\", y_test)\n", + "np.save(extractedDir + \"RAW/x_val\", x_val)\n", + "np.save(extractedDir + \"RAW/y_val\", y_val)\n", + "\n", + "print(extractedDir)" + ] + }, + { + "cell_type": "code", + "execution_count": 250, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-18T17:26:35.381650Z", + "start_time": "2019-06-18T17:26:35.136645Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x_test.npy x_train.npy x_val.npy y_test.npy y_train.npy y_val.npy\r\n" + ] + } + ], + "source": [ + "ls /home/sf/data/SWELL-KW/RAW" + ] + }, + { + "cell_type": "code", + "execution_count": 251, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-18T17:27:34.458130Z", + "start_time": "2019-06-18T17:27:34.323712Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(3679, 20, 22)" + ] + }, + "execution_count": 251, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.load('//home/sf/data/SWELL-KW/RAW/x_train.npy').shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Make 4D EMI Data (Bags, Subinstances, Subinstance Length, Features)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:30:53.720179Z", + "start_time": "2019-06-19T05:30:53.713756Z" + } + }, + "outputs": [], + "source": [ + "def loadData(dirname):\n", + " x_train = np.load(dirname + '/' + 'x_train.npy')\n", + " y_train = np.load(dirname + '/' + 'y_train.npy')\n", + " x_test = np.load(dirname + '/' + 'x_test.npy')\n", + " y_test = np.load(dirname + '/' + 'y_test.npy')\n", + " x_val = np.load(dirname + '/' + 'x_val.npy')\n", + " y_val = np.load(dirname + '/' + 'y_val.npy')\n", + " return x_train, y_train, x_test, y_test, x_val, y_val" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-18T17:29:23.614052Z", + "start_time": "2019-06-18T17:29:23.601324Z" + } + }, + "outputs": [], + "source": [ + "def bagData(X, Y, subinstanceLen, subinstanceStride):\n", + " numClass = 2\n", + " numSteps = 20\n", + " numFeats = 22\n", + " assert X.ndim == 3\n", + " assert X.shape[1] == numSteps\n", + " assert X.shape[2] == numFeats\n", + " assert subinstanceLen <= numSteps\n", + " assert subinstanceLen > 0\n", + " assert subinstanceStride <= numSteps\n", + " assert subinstanceStride >= 0\n", + " assert len(X) == len(Y)\n", + " assert Y.ndim == 2\n", + " assert Y.shape[1] == numClass\n", + " x_bagged = []\n", + " y_bagged = []\n", + " for i, point in enumerate(X[:, :, :]):\n", + " instanceList = []\n", + " start = 0\n", + " end = subinstanceLen\n", + " while True:\n", + " x = point[start:end, :]\n", + " if len(x) < subinstanceLen:\n", + " x_ = np.zeros([subinstanceLen, x.shape[1]])\n", + " x_[:len(x), :] = x[:, :]\n", + " x = x_\n", + " instanceList.append(x)\n", + " if end >= numSteps:\n", + " break\n", + " start += subinstanceStride\n", + " end += subinstanceStride\n", + " bag = np.array(instanceList)\n", + " numSubinstance = bag.shape[0]\n", + " label = Y[i]\n", + " label = np.argmax(label)\n", + " labelBag = np.zeros([numSubinstance, numClass])\n", + " labelBag[:, label] = 1\n", + " x_bagged.append(bag)\n", + " label = np.array(labelBag)\n", + " y_bagged.append(label)\n", + " return np.array(x_bagged), np.array(y_bagged)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:33:06.531994Z", + "start_time": "2019-06-19T05:33:06.523884Z" + } + }, + "outputs": [], + "source": [ + "def makeEMIData(subinstanceLen, subinstanceStride, sourceDir, outDir):\n", + " x_train, y_train, x_test, y_test, x_val, y_val = loadData(sourceDir)\n", + " x, y = bagData(x_train, y_train, subinstanceLen, subinstanceStride)\n", + " np.save(outDir + '/x_train.npy', x)\n", + " np.save(outDir + '/y_train.npy', y)\n", + " print('Num train %d' % len(x))\n", + " x, y = bagData(x_test, y_test, subinstanceLen, subinstanceStride)\n", + " np.save(outDir + '/x_test.npy', x)\n", + " np.save(outDir + '/y_test.npy', y)\n", + " print('Num test %d' % len(x))\n", + " x, y = bagData(x_val, y_val, subinstanceLen, subinstanceStride)\n", + " np.save(outDir + '/x_val.npy', x)\n", + " np.save(outDir + '/y_val.npy', y)\n", + " print('Num val %d' % len(x))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:35:20.960014Z", + "start_time": "2019-06-19T05:35:20.050363Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Num train 3679\n", + "Num test 1022\n", + "Num val 409\n" + ] + } + ], + "source": [ + "subinstanceLen = 8\n", + "subinstanceStride = 3\n", + "extractedDir = '/home/sf/data/SWELL-KW'\n", + "from os import mkdir\n", + "#mkdir('/home/sf/data/SWELL-KW/8_3')\n", + "rawDir = extractedDir + '/RAW'\n", + "sourceDir = rawDir\n", + "outDir = extractedDir + '/%d_%d/' % (subinstanceLen, subinstanceStride)\n", + "makeEMIData(subinstanceLen, subinstanceStride, sourceDir, outDir)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-18T17:48:58.293843Z", + "start_time": "2019-06-18T17:48:58.285383Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(3679, 5, 2)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.load('/home/sf/data/SWELL-KW/8_3/y_train.npy').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:35:48.609552Z", + "start_time": "2019-06-19T05:35:48.604291Z" + } + }, + "outputs": [], + "source": [ + "from edgeml.graph.rnn import EMI_DataPipeline\n", + "from edgeml.graph.rnn import EMI_BasicLSTM, EMI_FastGRNN, EMI_FastRNN, EMI_GRU\n", + "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n", + "import edgeml.utils" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:43:48.032413Z", + "start_time": "2019-06-19T05:43:47.987839Z" + } + }, + "outputs": [], + "source": [ + "def lstm_experiment_generator(params, path = './DSAAR/64_16/'):\n", + " \"\"\"\n", + " Function that will generate the experiments to be run.\n", + " Inputs : \n", + " (1) Dictionary params, to set the network parameters.\n", + " (2) Name of the Model to be run from [EMI-LSTM, EMI-FastGRNN, EMI-GRU]\n", + " (3) Path to the dataset, where the csv files are present.\n", + " \"\"\"\n", + " \n", + " #Copy the contents of the params dictionary.\n", + " lstm_dict = {**params}\n", + " \n", + " #---------------------------PARAM SETTING----------------------#\n", + " \n", + " # Network parameters for our LSTM + FC Layer\n", + " NUM_HIDDEN = params[\"NUM_HIDDEN\"]\n", + " NUM_TIMESTEPS = params[\"NUM_TIMESTEPS\"]\n", + " ORIGINAL_NUM_TIMESTEPS = params[\"ORIGINAL_NUM_TIMESTEPS\"]\n", + " NUM_FEATS = params[\"NUM_FEATS\"]\n", + " FORGET_BIAS = params[\"FORGET_BIAS\"]\n", + " NUM_OUTPUT = params[\"NUM_OUTPUT\"]\n", + " USE_DROPOUT = True if (params[\"USE_DROPOUT\"] == 1) else False\n", + " KEEP_PROB = params[\"KEEP_PROB\"]\n", + "\n", + " # For dataset API\n", + " PREFETCH_NUM = params[\"PREFETCH_NUM\"]\n", + " BATCH_SIZE = params[\"BATCH_SIZE\"]\n", + "\n", + " # Number of epochs in *one iteration*\n", + " NUM_EPOCHS = params[\"NUM_EPOCHS\"]\n", + " # Number of iterations in *one round*. After each iteration,\n", + " # the model is dumped to disk. At the end of the current\n", + " # round, the best model among all the dumped models in the\n", + " # current round is picked up..\n", + " NUM_ITER = params[\"NUM_ITER\"]\n", + " # A round consists of multiple training iterations and a belief\n", + " # update step using the best model from all of these iterations\n", + " NUM_ROUNDS = params[\"NUM_ROUNDS\"]\n", + " LEARNING_RATE = params[\"LEARNING_RATE\"]\n", + "\n", + " # A staging direcory to store models\n", + " MODEL_PREFIX = params[\"MODEL_PREFIX\"]\n", + " \n", + " #----------------------END OF PARAM SETTING----------------------#\n", + " \n", + " #----------------------DATA LOADING------------------------------#\n", + " \n", + " x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n", + " x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n", + " x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n", + "\n", + " # BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n", + " # step of EMI/MI RNN\n", + " BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n", + " BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n", + " BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n", + " NUM_SUBINSTANCE = x_train.shape[1]\n", + " print(\"x_train shape is:\", x_train.shape)\n", + " print(\"y_train shape is:\", y_train.shape)\n", + " print(\"x_test shape is:\", x_val.shape)\n", + " print(\"y_test shape is:\", y_val.shape)\n", + " \n", + " #----------------------END OF DATA LOADING------------------------------# \n", + " \n", + " #----------------------COMPUTATION GRAPH--------------------------------#\n", + " \n", + " # Define the linear secondary classifier\n", + " def createExtendedGraph(self, baseOutput, *args, **kwargs):\n", + " W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n", + " B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n", + " y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n", + " self.output = y_cap\n", + " self.graphCreated = True\n", + "\n", + " def restoreExtendedGraph(self, graph, *args, **kwargs):\n", + " y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n", + " self.output = y_cap\n", + " self.graphCreated = True\n", + "\n", + " def feedDictFunc(self, keep_prob=None, inference=False, **kwargs):\n", + " if inference is False:\n", + " feedDict = {self._emiGraph.keep_prob: keep_prob}\n", + " else:\n", + " feedDict = {self._emiGraph.keep_prob: 1.0}\n", + " return feedDict\n", + "\n", + " EMI_BasicLSTM._createExtendedGraph = createExtendedGraph\n", + " EMI_BasicLSTM._restoreExtendedGraph = restoreExtendedGraph\n", + "\n", + " if USE_DROPOUT is True:\n", + " EMI_Driver.feedDictFunc = feedDictFunc\n", + " \n", + " inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n", + " emiLSTM = EMI_BasicLSTM(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS,\n", + " forgetBias=FORGET_BIAS, useDropout=USE_DROPOUT)\n", + " emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy',\n", + " stepSize=LEARNING_RATE)\n", + " \n", + " tf.reset_default_graph()\n", + " g1 = tf.Graph() \n", + " with g1.as_default():\n", + " # Obtain the iterators to each batch of the data\n", + " x_batch, y_batch = inputPipeline()\n", + " # Create the forward computation graph based on the iterators\n", + " y_cap = emiLSTM(x_batch)\n", + " # Create loss graphs and training routines\n", + " emiTrainer(y_cap, y_batch)\n", + " \n", + " #------------------------------END OF COMPUTATION GRAPH------------------------------#\n", + " \n", + " #-------------------------------------EMI DRIVER-------------------------------------#\n", + " \n", + " with g1.as_default():\n", + " emiDriver = EMI_Driver(inputPipeline, emiLSTM, emiTrainer)\n", + "\n", + " emiDriver.initializeSession(g1)\n", + " y_updated, modelStats = emiDriver.run(numClasses=NUM_OUTPUT, x_train=x_train,\n", + " y_train=y_train, bag_train=BAG_TRAIN,\n", + " x_val=x_val, y_val=y_val, bag_val=BAG_VAL,\n", + " numIter=NUM_ITER, keep_prob=KEEP_PROB,\n", + " numRounds=NUM_ROUNDS, batchSize=BATCH_SIZE,\n", + " numEpochs=NUM_EPOCHS, modelPrefix=MODEL_PREFIX,\n", + " fracEMI=0.5, updatePolicy='top-k', k=1)\n", + " \n", + " #-------------------------------END OF EMI DRIVER-------------------------------------#\n", + " \n", + " #-----------------------------------EARLY SAVINGS-------------------------------------#\n", + " \n", + " \"\"\"\n", + " Early Prediction Policy: We make an early prediction based on the predicted classes\n", + " probability. If the predicted class probability > minProb at some step, we make\n", + " a prediction at that step.\n", + " \"\"\"\n", + " def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n", + " assert instanceOut.ndim == 2\n", + " classes = np.argmax(instanceOut, axis=1)\n", + " prob = np.max(instanceOut, axis=1)\n", + " index = np.where(prob >= minProb)[0]\n", + " if len(index) == 0:\n", + " assert (len(instanceOut) - 1) == (len(classes) - 1)\n", + " return classes[-1], len(instanceOut) - 1\n", + " index = index[0]\n", + " return classes[index], index\n", + "\n", + " def getEarlySaving(predictionStep, numTimeSteps, returnTotal=False):\n", + " predictionStep = predictionStep + 1\n", + " predictionStep = np.reshape(predictionStep, -1)\n", + " totalSteps = np.sum(predictionStep)\n", + " maxSteps = len(predictionStep) * numTimeSteps\n", + " savings = 1.0 - (totalSteps / maxSteps)\n", + " if returnTotal:\n", + " return savings, totalSteps\n", + " return savings\n", + " \n", + " #--------------------------------END OF EARLY SAVINGS---------------------------------#\n", + " \n", + " #----------------------------------------BEST MODEL-----------------------------------#\n", + " \n", + " k = 2\n", + " predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n", + " minProb=0.99, keep_prob=1.0)\n", + " bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n", + " print('Accuracy at k = %d: %f' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))))\n", + " mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n", + " emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n", + " total_savings = mi_savings + (1 - mi_savings) * emi_savings\n", + " print('Savings due to MI-RNN : %f' % mi_savings)\n", + " print('Savings due to Early prediction: %f' % emi_savings)\n", + " print('Total Savings: %f' % (total_savings))\n", + " \n", + " #Store in the dictionary.\n", + " lstm_dict[\"k\"] = k\n", + " lstm_dict[\"accuracy\"] = np.mean((bagPredictions == BAG_TEST).astype(int))\n", + " lstm_dict[\"total_savings\"] = total_savings\n", + " lstm_dict[\"y_test\"] = BAG_TEST\n", + " lstm_dict[\"y_pred\"] = bagPredictions\n", + " \n", + " # A slightly more detailed analysis method is provided. \n", + " df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)\n", + " print (tabulate(df, headers=list(df.columns), tablefmt='grid'))\n", + " \n", + " lstm_dict[\"detailed analysis\"] = df\n", + " #----------------------------------END OF BEST MODEL-----------------------------------#\n", + " \n", + " #----------------------------------PICKING THE BEST MODEL------------------------------#\n", + " \n", + " devnull = open(os.devnull, 'r')\n", + " for val in modelStats:\n", + " round_, acc, modelPrefix, globalStep = val\n", + " emiDriver.loadSavedGraphToNewSession(modelPrefix, globalStep, redirFile=devnull)\n", + " predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n", + " minProb=0.99, keep_prob=1.0)\n", + "\n", + " bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n", + " print(\"Round: %2d, Validation accuracy: %.4f\" % (round_, acc), end='')\n", + " print(', Test Accuracy (k = %d): %f, ' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))), end='')\n", + " print('Additional savings: %f' % getEarlySaving(predictionStep, NUM_TIMESTEPS)) \n", + " \n", + " \n", + " #-------------------------------END OF PICKING THE BEST MODEL--------------------------#\n", + "\n", + " return lstm_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:43:51.798834Z", + "start_time": "2019-06-19T05:43:51.792938Z" + } + }, + "outputs": [], + "source": [ + "def experiment_generator(params, path, model = 'lstm'):\n", + " \n", + " \n", + " if (model == 'lstm'): return lstm_experiment_generator(params, path)\n", + " elif (model == 'fastgrnn'): return fastgrnn_experiment_generator(params, path)\n", + " elif (model == 'gru'): return gru_experiment_generator(params, path)\n", + " elif (model == 'baseline'): return baseline_experiment_generator(params, path)\n", + " \n", + " return " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/home/sf/data/EdgeML/tf/examples/EMI-RNN'" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pwd" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T05:43:53.168413Z", + "start_time": "2019-06-19T05:43:53.164622Z" + } + }, + "outputs": [], + "source": [ + "import tensorflow as tf" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2019-06-19T08:16:13.770690Z", + "start_time": "2019-06-19T05:45:31.777404Z" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x_train shape is: (3679, 5, 8, 22)\n", + "y_train shape is: (3679, 5, 2)\n", + "x_test shape is: (409, 5, 8, 22)\n", + "y_test shape is: (409, 5, 2)\n", + "Update policy: top-k\n", + "Training with MI-RNN loss for 15 rounds\n", + "Round: 0\n", + "Epoch 1 Batch 110 ( 225) Loss 0.08400 Acc 0.59375 | Val acc 0.58924 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1000\n", + "Epoch 1 Batch 110 ( 225) Loss 0.08510 Acc 0.57500 | Val acc 0.63081 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1001\n", + "Epoch 1 Batch 110 ( 225) Loss 0.07742 Acc 0.61250 | Val acc 0.66015 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1002\n", + "Epoch 1 Batch 110 ( 225) Loss 0.07279 Acc 0.67500 | Val acc 0.66504 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1003\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1003\n", + "Round: 1\n", + "Epoch 1 Batch 110 ( 225) Loss 0.06990 Acc 0.71875 | Val acc 0.67971 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1004\n", + "Epoch 1 Batch 110 ( 225) Loss 0.06914 Acc 0.72500 | Val acc 0.68215 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1005\n", + "Epoch 1 Batch 110 ( 225) Loss 0.06405 Acc 0.76250 | Val acc 0.69682 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1006\n", + "Epoch 1 Batch 110 ( 225) Loss 0.06493 Acc 0.73125 | Val acc 0.70660 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1007\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1007\n", + "Round: 2\n", + "Epoch 1 Batch 110 ( 225) Loss 0.06088 Acc 0.76875 | Val acc 0.72127 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1008\n", + "Epoch 1 Batch 110 ( 225) Loss 0.05882 Acc 0.75000 | Val acc 0.73594 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1009\n", + "Epoch 1 Batch 110 ( 225) Loss 0.05941 Acc 0.79375 | Val acc 0.74083 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1010\n", + "Epoch 1 Batch 110 ( 225) Loss 0.06156 Acc 0.76875 | Val acc 0.74817 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1011\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1011\n", + "Round: 3\n", + "Epoch 1 Batch 110 ( 225) Loss 0.06403 Acc 0.76250 | Val acc 0.75550 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1012\n", + "Epoch 1 Batch 110 ( 225) Loss 0.05794 Acc 0.75625 | Val acc 0.75795 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1013\n", + "Epoch 1 Batch 110 ( 225) Loss 0.05490 Acc 0.82500 | Val acc 0.76773 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1014\n", + "Epoch 1 Batch 110 ( 225) Loss 0.05780 Acc 0.81250 | Val acc 0.78240 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1015\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1015\n", + "Round: 4\n", + "Epoch 1 Batch 110 ( 225) Loss 0.05379 Acc 0.83750 | Val acc 0.77995 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1016\n", + "Epoch 1 Batch 110 ( 225) Loss 0.05840 Acc 0.80000 | Val acc 0.78240 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1017\n", + "Epoch 1 Batch 110 ( 225) Loss 0.05547 Acc 0.81250 | Val acc 0.77506 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1018\n", + "Epoch 1 Batch 110 ( 225) Loss 0.05212 Acc 0.81875 | Val acc 0.78973 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1019\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1019\n", + "Round: 5\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04983 Acc 0.84375 | Val acc 0.78973 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1020\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04638 Acc 0.85000 | Val acc 0.80196 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1021\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04734 Acc 0.83750 | Val acc 0.80440 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1022\n", + "Epoch 1 Batch 110 ( 225) Loss 0.05063 Acc 0.85625 | Val acc 0.80440 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1023\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1022\n", + "Round: 6\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04497 Acc 0.86875 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1024\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04768 Acc 0.83125 | Val acc 0.79951 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1025\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04779 Acc 0.81875 | Val acc 0.80929 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1026\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04416 Acc 0.81250 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1027\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1027\n", + "Round: 7\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04235 Acc 0.84375 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1028\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04155 Acc 0.84375 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1029\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03803 Acc 0.85000 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1030\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04139 Acc 0.85000 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1031\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1031\n", + "Round: 8\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04385 Acc 0.81875 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1032\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03797 Acc 0.87500 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1033\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03902 Acc 0.83750 | Val acc 0.80929 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1034\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03691 Acc 0.85000 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1035\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1032\n", + "Round: 9\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03872 Acc 0.87500 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1036\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04312 Acc 0.82500 | Val acc 0.81907 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1037\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03735 Acc 0.87500 | Val acc 0.82152 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1038\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04043 Acc 0.84375 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1039\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1039\n", + "Round: 10\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03696 Acc 0.85000 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1040\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03923 Acc 0.86250 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1041\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03495 Acc 0.87500 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1042\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03587 Acc 0.88750 | Val acc 0.81907 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1043\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1042\n", + "Round: 11\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04023 Acc 0.86250 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1044\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03841 Acc 0.85625 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1045\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03969 Acc 0.85000 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1046\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 Batch 110 ( 225) Loss 0.03921 Acc 0.85000 | Val acc 0.81418 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1047\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1044\n", + "Round: 12\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04475 Acc 0.81875 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1048\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03966 Acc 0.85625 | Val acc 0.83130 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1049\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03139 Acc 0.90000 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1050\n", + "Epoch 1 Batch 110 ( 225) Loss 0.04275 Acc 0.81250 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1051\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1050\n", + "Round: 13\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03359 Acc 0.87500 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1052\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03795 Acc 0.86250 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1053\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03523 Acc 0.87500 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1054\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03383 Acc 0.89375 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1055\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1054\n", + "Round: 14\n", + "Epoch 1 Batch 110 ( 225) Loss 0.02968 Acc 0.89375 | Val acc 0.81418 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1056\n", + "Epoch 1 Batch 110 ( 225) Loss 0.03354 Acc 0.85625 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1057\n", + "Epoch 1 Batch 110 ( 225) Loss 0.02782 Acc 0.90000 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1058\n", + "Epoch 1 Batch 110 ( 225) Loss 0.02719 Acc 0.89375 | Val acc 0.82641 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1059\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1058\n", + "Round: 15\n", + "Switching to EMI-Loss function\n", + "Epoch 1 Batch 110 ( 225) Loss 0.44879 Acc 0.84375 | Val acc 0.80929 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1060\n", + "Epoch 1 Batch 110 ( 225) Loss 0.42052 Acc 0.85000 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1061\n", + "Epoch 1 Batch 110 ( 225) Loss 0.37942 Acc 0.87500 | Val acc 0.82152 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1062\n", + "Epoch 1 Batch 110 ( 225) Loss 0.38046 Acc 0.88125 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1063\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1063\n", + "Round: 16\n", + "Epoch 1 Batch 110 ( 225) Loss 0.38822 Acc 0.86875 | Val acc 0.81663 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1064\n", + "Epoch 1 Batch 110 ( 225) Loss 0.37139 Acc 0.88750 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1065\n", + "Epoch 1 Batch 110 ( 225) Loss 0.38185 Acc 0.88125 | Val acc 0.82152 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1066\n", + "Epoch 1 Batch 110 ( 225) Loss 0.37084 Acc 0.86875 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1067\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1065\n", + "Round: 17\n", + "Epoch 1 Batch 110 ( 225) Loss 0.37820 Acc 0.89375 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1068\n", + "Epoch 1 Batch 110 ( 225) Loss 0.36451 Acc 0.84375 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1069\n", + "Epoch 1 Batch 110 ( 225) Loss 0.36738 Acc 0.89375 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1070\n", + "Epoch 1 Batch 110 ( 225) Loss 0.40073 Acc 0.88750 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1071\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1068\n", + "Round: 18\n", + "Epoch 1 Batch 110 ( 225) Loss 0.38454 Acc 0.87500 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1072\n", + "Epoch 1 Batch 110 ( 225) Loss 0.36183 Acc 0.90000 | Val acc 0.82396 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1073\n", + "Epoch 1 Batch 110 ( 225) Loss 0.38793 Acc 0.90000 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1074\n", + "Epoch 1 Batch 110 ( 225) Loss 0.38047 Acc 0.83750 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1075\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1072\n", + "Round: 19\n", + "Epoch 1 Batch 110 ( 225) Loss 0.38943 Acc 0.85625 | Val acc 0.82885 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1076\n", + "Epoch 1 Batch 110 ( 225) Loss 0.40676 Acc 0.86250 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1077\n", + "Epoch 1 Batch 110 ( 225) Loss 0.37826 Acc 0.90000 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1078\n", + "Epoch 1 Batch 110 ( 225) Loss 0.39518 Acc 0.86250 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1079\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1078\n", + "Round: 20\n", + "Epoch 1 Batch 110 ( 225) Loss 0.38738 Acc 0.85625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1080\n", + "Epoch 1 Batch 110 ( 225) Loss 0.38468 Acc 0.85625 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1081\n", + "Epoch 1 Batch 110 ( 225) Loss 0.35865 Acc 0.89375 | Val acc 0.83130 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1082\n", + "Epoch 1 Batch 110 ( 225) Loss 0.33802 Acc 0.92500 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1083\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1083\n", + "Round: 21\n", + "Epoch 1 Batch 110 ( 225) Loss 0.35882 Acc 0.89375 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1084\n", + "Epoch 1 Batch 110 ( 225) Loss 0.37442 Acc 0.85625 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1085\n", + "Epoch 1 Batch 110 ( 225) Loss 0.40794 Acc 0.87500 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1086\n", + "Epoch 1 Batch 110 ( 225) Loss 0.33659 Acc 0.91250 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1087\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1084\n", + "Round: 22\n", + "Epoch 1 Batch 110 ( 225) Loss 0.33654 Acc 0.88125 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1088\n", + "Epoch 1 Batch 110 ( 225) Loss 0.36120 Acc 0.88125 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1089\n", + "Epoch 1 Batch 110 ( 225) Loss 0.34685 Acc 0.88125 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1090\n", + "Epoch 1 Batch 110 ( 225) Loss 0.32916 Acc 0.93125 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1091\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1091\n", + "Round: 23\n", + "Epoch 1 Batch 110 ( 225) Loss 0.34161 Acc 0.90625 | Val acc 0.83130 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1092\n", + "Epoch 1 Batch 110 ( 225) Loss 0.39510 Acc 0.85000 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1093\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 Batch 110 ( 225) Loss 0.37442 Acc 0.91875 | Val acc 0.84841 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1094\n", + "Epoch 1 Batch 110 ( 225) Loss 0.32392 Acc 0.91875 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1095\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1094\n", + "Round: 24\n", + "Epoch 1 Batch 110 ( 225) Loss 0.32550 Acc 0.90625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1096\n", + "Epoch 1 Batch 110 ( 225) Loss 0.34901 Acc 0.90625 | Val acc 0.83619 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1097\n", + "Epoch 1 Batch 110 ( 225) Loss 0.34062 Acc 0.90625 | Val acc 0.84108 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1098\n", + "Epoch 1 Batch 110 ( 225) Loss 0.33095 Acc 0.90625 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1099\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1099\n", + "Round: 25\n", + "Epoch 1 Batch 110 ( 225) Loss 0.34793 Acc 0.90000 | Val acc 0.85819 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1100\n", + "Epoch 1 Batch 110 ( 225) Loss 0.32091 Acc 0.92500 | Val acc 0.86553 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1101\n", + "Epoch 1 Batch 110 ( 225) Loss 0.35738 Acc 0.90625 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1102\n", + "Epoch 1 Batch 110 ( 225) Loss 0.34036 Acc 0.90000 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1103\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1101\n", + "Round: 26\n", + "Epoch 1 Batch 110 ( 225) Loss 0.34306 Acc 0.88750 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1104\n", + "Epoch 1 Batch 110 ( 225) Loss 0.32850 Acc 0.92500 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1105\n", + "Epoch 1 Batch 110 ( 225) Loss 0.34266 Acc 0.88750 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1106\n", + "Epoch 1 Batch 110 ( 225) Loss 0.34165 Acc 0.89375 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1107\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1105\n", + "Round: 27\n", + "Epoch 1 Batch 110 ( 225) Loss 0.32461 Acc 0.89375 | Val acc 0.84597 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1108\n", + "Epoch 1 Batch 110 ( 225) Loss 0.33608 Acc 0.89375 | Val acc 0.84352 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1109\n", + "Epoch 1 Batch 110 ( 225) Loss 0.33137 Acc 0.90625 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1110\n", + "Epoch 1 Batch 110 ( 225) Loss 0.28914 Acc 0.95000 | Val acc 0.83374 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1111\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1110\n", + "Round: 28\n", + "Epoch 1 Batch 110 ( 225) Loss 0.35324 Acc 0.92500 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1112\n", + "Epoch 1 Batch 110 ( 225) Loss 0.31689 Acc 0.91875 | Val acc 0.86797 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1113\n", + "Epoch 1 Batch 110 ( 225) Loss 0.32767 Acc 0.92500 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1114\n", + "Epoch 1 Batch 110 ( 225) Loss 0.30775 Acc 0.91875 | Val acc 0.85330 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1115\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1113\n", + "Round: 29\n", + "Epoch 1 Batch 110 ( 225) Loss 0.33176 Acc 0.91250 | Val acc 0.85086 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1116\n", + "Epoch 1 Batch 110 ( 225) Loss 0.31769 Acc 0.90625 | Val acc 0.83863 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1117\n", + "Epoch 1 Batch 110 ( 225) Loss 0.32659 Acc 0.90625 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1118\n", + "Epoch 1 Batch 110 ( 225) Loss 0.31592 Acc 0.93125 | Val acc 0.85575 | Model saved to /home/sf/data/SWELL-KW/models/model-lstm, global_step 1119\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1118\n", + "Accuracy at k = 2: 0.863992\n", + "Savings due to MI-RNN : 0.600000\n", + "Savings due to Early prediction: 0.208659\n", + "Total Savings: 0.683464\n", + " len acc macro-fsc macro-pre macro-rec micro-fsc micro-pre \\\n", + "0 1 0.842466 0.842041 0.849946 0.844332 0.842466 0.842466 \n", + "1 2 0.863992 0.863982 0.865217 0.864776 0.863992 0.863992 \n", + "2 3 0.867906 0.867860 0.867807 0.867995 0.867906 0.867906 \n", + "3 4 0.864971 0.864697 0.865674 0.864385 0.864971 0.864971 \n", + "4 5 0.857143 0.856430 0.860480 0.855905 0.857143 0.857143 \n", + "\n", + " micro-rec fscore_01 \n", + "0 0.842466 0.850233 \n", + "1 0.863992 0.865179 \n", + "2 0.867906 0.865404 \n", + "3 0.864971 0.858607 \n", + "4 0.857143 0.846316 \n", + "Max accuracy 0.867906 at subsequencelength 3\n", + "Max micro-f 0.867906 at subsequencelength 3\n", + "Micro-precision 0.867906 at subsequencelength 3\n", + "Micro-recall 0.867906 at subsequencelength 3\n", + "Max macro-f 0.867860 at subsequencelength 3\n", + "macro-precision 0.867807 at subsequencelength 3\n", + "macro-recall 0.867995 at subsequencelength 3\n", + "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n", + "| | len | acc | macro-fsc | macro-pre | macro-rec | micro-fsc | micro-pre | micro-rec | fscore_01 |\n", + "+====+=======+==========+=============+=============+=============+=============+=============+=============+=============+\n", + "| 0 | 1 | 0.842466 | 0.842041 | 0.849946 | 0.844332 | 0.842466 | 0.842466 | 0.842466 | 0.850233 |\n", + "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n", + "| 1 | 2 | 0.863992 | 0.863982 | 0.865217 | 0.864776 | 0.863992 | 0.863992 | 0.863992 | 0.865179 |\n", + "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n", + "| 2 | 3 | 0.867906 | 0.86786 | 0.867807 | 0.867995 | 0.867906 | 0.867906 | 0.867906 | 0.865404 |\n", + "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n", + "| 3 | 4 | 0.864971 | 0.864697 | 0.865674 | 0.864385 | 0.864971 | 0.864971 | 0.864971 | 0.858607 |\n", + "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n", + "| 4 | 5 | 0.857143 | 0.85643 | 0.86048 | 0.855905 | 0.857143 | 0.857143 | 0.857143 | 0.846316 |\n", + "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1003\n", + "Round: 0, Validation accuracy: 0.6650, Test Accuracy (k = 2): 0.684932, Additional savings: 0.006531\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1007\n", + "Round: 1, Validation accuracy: 0.7066, Test Accuracy (k = 2): 0.687867, Additional savings: 0.008072\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1011\n", + "Round: 2, Validation accuracy: 0.7482, Test Accuracy (k = 2): 0.721135, Additional savings: 0.014555\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1015\n", + "Round: 3, Validation accuracy: 0.7824, Test Accuracy (k = 2): 0.745597, Additional savings: 0.020597\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1019\n", + "Round: 4, Validation accuracy: 0.7897, Test Accuracy (k = 2): 0.776908, Additional savings: 0.027740\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1022\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Round: 5, Validation accuracy: 0.8044, Test Accuracy (k = 2): 0.793542, Additional savings: 0.037769\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1027\n", + "Round: 6, Validation accuracy: 0.8337, Test Accuracy (k = 2): 0.817025, Additional savings: 0.049315\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1031\n", + "Round: 7, Validation accuracy: 0.8289, Test Accuracy (k = 2): 0.826810, Additional savings: 0.054795\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1032\n", + "Round: 8, Validation accuracy: 0.8362, Test Accuracy (k = 2): 0.825832, Additional savings: 0.060372\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1039\n", + "Round: 9, Validation accuracy: 0.8264, Test Accuracy (k = 2): 0.821918, Additional savings: 0.068249\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1042\n", + "Round: 10, Validation accuracy: 0.8289, Test Accuracy (k = 2): 0.820939, Additional savings: 0.074609\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1044\n", + "Round: 11, Validation accuracy: 0.8362, Test Accuracy (k = 2): 0.830724, Additional savings: 0.078914\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1050\n", + "Round: 12, Validation accuracy: 0.8362, Test Accuracy (k = 2): 0.835616, Additional savings: 0.090411\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1054\n", + "Round: 13, Validation accuracy: 0.8411, Test Accuracy (k = 2): 0.835616, Additional savings: 0.098973\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1058\n", + "Round: 14, Validation accuracy: 0.8460, Test Accuracy (k = 2): 0.829746, Additional savings: 0.109638\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1063\n", + "Round: 15, Validation accuracy: 0.8337, Test Accuracy (k = 2): 0.827789, Additional savings: 0.111913\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1065\n", + "Round: 16, Validation accuracy: 0.8289, Test Accuracy (k = 2): 0.820939, Additional savings: 0.128156\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1068\n", + "Round: 17, Validation accuracy: 0.8411, Test Accuracy (k = 2): 0.830724, Additional savings: 0.129525\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1072\n", + "Round: 18, Validation accuracy: 0.8337, Test Accuracy (k = 2): 0.829746, Additional savings: 0.134760\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1078\n", + "Round: 19, Validation accuracy: 0.8386, Test Accuracy (k = 2): 0.837573, Additional savings: 0.143102\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1083\n", + "Round: 20, Validation accuracy: 0.8435, Test Accuracy (k = 2): 0.835616, Additional savings: 0.161717\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1084\n", + "Round: 21, Validation accuracy: 0.8411, Test Accuracy (k = 2): 0.850294, Additional savings: 0.167392\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1091\n", + "Round: 22, Validation accuracy: 0.8509, Test Accuracy (k = 2): 0.846380, Additional savings: 0.177886\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1094\n", + "Round: 23, Validation accuracy: 0.8484, Test Accuracy (k = 2): 0.860078, Additional savings: 0.175612\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1099\n", + "Round: 24, Validation accuracy: 0.8460, Test Accuracy (k = 2): 0.851272, Additional savings: 0.189750\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1101\n", + "Round: 25, Validation accuracy: 0.8655, Test Accuracy (k = 2): 0.854207, Additional savings: 0.193395\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1105\n", + "Round: 26, Validation accuracy: 0.8557, Test Accuracy (k = 2): 0.853229, Additional savings: 0.197114\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1110\n", + "Round: 27, Validation accuracy: 0.8557, Test Accuracy (k = 2): 0.851272, Additional savings: 0.199731\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1113\n", + "Round: 28, Validation accuracy: 0.8680, Test Accuracy (k = 2): 0.860078, Additional savings: 0.208586\n", + "INFO:tensorflow:Restoring parameters from /home/sf/data/SWELL-KW/models/model-lstm-1118\n", + "Round: 29, Validation accuracy: 0.8557, Test Accuracy (k = 2): 0.863992, Additional savings: 0.208659\n", + "Results for this run have been saved at lstm .\n" + ] + } + ], + "source": [ + "## Baseline EMI-LSTM\n", + "\n", + "dataset = 'SWELL-KW'\n", + "path = '/home/sf/data/SWELL-KW/8_3/'\n", + "\n", + "#Choose model from among [lstm, fastgrnn, gru]\n", + "model = 'lstm'\n", + "\n", + "# Dictionary to set the parameters.\n", + "params = {\n", + " \"NUM_HIDDEN\" : 128,\n", + " \"NUM_TIMESTEPS\" : 8, #subinstance length.\n", + " \"ORIGINAL_NUM_TIMESTEPS\" : 20,\n", + " \"NUM_FEATS\" : 22,\n", + " \"FORGET_BIAS\" : 1.0,\n", + " \"NUM_OUTPUT\" : 2,\n", + " \"USE_DROPOUT\" : 1, # '1' -> True. '0' -> False\n", + " \"KEEP_PROB\" : 0.75,\n", + " \"PREFETCH_NUM\" : 5,\n", + " \"BATCH_SIZE\" : 32,\n", + " \"NUM_EPOCHS\" : 2,\n", + " \"NUM_ITER\" : 4,\n", + " \"NUM_ROUNDS\" : 30,\n", + " \"LEARNING_RATE\" : 0.001,\n", + " \"FRAC_EMI\" : 0.5,\n", + " \"MODEL_PREFIX\" : '/home/sf/data/SWELL-KW/models/model-lstm'\n", + "}\n", + "\n", + "#Preprocess data, and load the train,test and validation splits.\n", + "lstm_dict = lstm_experiment_generator(params, path)\n", + "\n", + "#Create the directory to store the results of this run.\n", + "\n", + "dirname = \"\" + model\n", + "pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)\n", + "print (\"Results for this run have been saved at\" , dirname, \".\")\n", + "\n", + "now = datetime.datetime.now()\n", + "filename = list((str(now.year),\"-\",str(now.month),\"-\",str(now.day),\"|\",str(now.hour),\"-\",str(now.minute)))\n", + "filename = ''.join(filename)\n", + "\n", + "#Save the dictionary containing the params and the results.\n", + "pkl.dump(lstm_dict,open(dirname + \"/lstm_dict_\" + filename + \".pkl\",mode='wb'))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}