Diff of /NewDatasetConvnet.ipynb [000000] .. [a8e9b4]

Switch to side-by-side view

--- a
+++ b/NewDatasetConvnet.ipynb
@@ -0,0 +1,394 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "from scipy import io\n",
+    "from scipy.signal import butter, lfilter\n",
+    "import h5py\n",
+    "import random\n",
+    "import numpy as np\n",
+    "import os"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "datafolder = \"new_dataset/\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "# some filtering code copypasted from provided notebook \n",
+    "\n",
+    "def butter_bandpass(lowcut, highcut, sampling_rate, order=5):\n",
+    "    nyq_freq = sampling_rate*0.5\n",
+    "    low = lowcut/nyq_freq\n",
+    "    high = highcut/nyq_freq\n",
+    "    b, a = butter(order, [low, high], btype='band')\n",
+    "    return b, a\n",
+    "\n",
+    "def butter_high_low_pass(lowcut, highcut, sampling_rate, order=5):\n",
+    "    nyq_freq = sampling_rate*0.5\n",
+    "    lower_bound = lowcut/nyq_freq\n",
+    "    higher_bound = highcut/nyq_freq\n",
+    "    b_high, a_high = butter(order, lower_bound, btype='high')\n",
+    "    b_low, a_low = butter(order, higher_bound, btype='low')\n",
+    "    return b_high, a_high, b_low, a_low\n",
+    "\n",
+    "def butter_bandpass_filter(data, lowcut, highcut, sampling_rate, order=5, how_to_filt = 'separately'):\n",
+    "    if how_to_filt == 'separately':\n",
+    "        b_high, a_high, b_low, a_low = butter_high_low_pass(lowcut, highcut, sampling_rate, order=order)\n",
+    "        y = lfilter(b_high, a_high, data)\n",
+    "        y = lfilter(b_low, a_low, y)\n",
+    "    elif how_to_filt == 'simultaneously':\n",
+    "        b, a = butter_bandpass(lowcut, highcut, sampling_rate, order=order)\n",
+    "        y = lfilter(b, a, data)\n",
+    "    return y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "def open_eeg_mat(filename, centered=True):\n",
+    "    all_data = io.loadmat(filename)\n",
+    "    eeg_data = all_data['data_cur']\n",
+    "    if centered:\n",
+    "        eeg_data = eeg_data - np.mean(eeg_data,1)[np.newaxis].T\n",
+    "        print('Data were centered: channels are zero-mean')\n",
+    "    states_labels = all_data['states_cur']\n",
+    "    states_codes = list(np.unique(states_labels)[:])\n",
+    "    sampling_rate = all_data['srate']\n",
+    "    chan_names = all_data['chan_names']\n",
+    "    return eeg_data, states_labels, sampling_rate, chan_names, eeg_data.shape[0], eeg_data.shape[1], states_codes\n",
+    "\n",
+    "def butter_high_low_pass(lowcut, highcut, sampling_rate, order=5):\n",
+    "    nyq_freq = sampling_rate*0.5\n",
+    "    lower_bound = lowcut/nyq_freq\n",
+    "    higher_bound = highcut/nyq_freq\n",
+    "    b_high, a_high = butter(order, lower_bound, btype='high')\n",
+    "    b_low, a_low = butter(order, higher_bound, btype='low')\n",
+    "    return b_high, a_high, b_low, a_low\n",
+    "\n",
+    "def butter_bandpass_filter(data, lowcut, highcut, sampling_rate, order=5, how_to_filt = 'simultaneously'):\n",
+    "    if how_to_filt == 'separately':\n",
+    "        b_high, a_high, b_low, a_low = butter_high_low_pass(lowcut, highcut, sampling_rate, order=order)\n",
+    "        y = lfilter(b_high, a_high, data)\n",
+    "        y = lfilter(b_low, a_low, y)\n",
+    "    elif how_to_filt == 'simultaneously':\n",
+    "        b, a = butter_bandpass(lowcut, highcut, sampling_rate, order=order)\n",
+    "        y = lfilter(b, a, data)\n",
+    "    return y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[1 1 1 ..., 2 2 2]\n",
+      "[1 1 1 ..., 6 6 6]\n",
+      "[6 6 6 ..., 6 6 6]\n",
+      "[1 1 1 ..., 6 6 6]\n",
+      "[1 1 1 ..., 2 2 2]\n",
+      "[1 1 1 ..., 2 2 2]\n",
+      "[1 1 1 ..., 6 6 6]\n",
+      "[1 1 1 ..., 6 6 6]\n",
+      "[1 1 1 ..., 6 6 6]\n",
+      "[6 6 6 ..., 6 6 6]\n",
+      "[1 1 1 ..., 2 2 2]\n",
+      "[1 1 1 ..., 2 2 2]\n",
+      "[1 1 1 ..., 2 2 2]\n",
+      "[1 1 1 ..., 6 6 6]\n",
+      "[1 1 1 ..., 6 6 6]\n",
+      "[6 6 6 ..., 6 6 6]\n",
+      "[1 1 1 ..., 6 6 6]\n",
+      "[1 1 1 ..., 2 2 2]\n",
+      "[1 1 1 ..., 2 2 2]\n",
+      "[6 6 6 ..., 2 2 2]\n"
+     ]
+    }
+   ],
+   "source": [
+    "train_datas = {}\n",
+    "test_datas = {}\n",
+    "\n",
+    "def to_onehot(label):\n",
+    "    labels_encoding = {1: np.array([1,0,0]), 2: np.array([0,1,0]), 6: np.array([0,0,1])}\n",
+    "    return labels_encoding[label]\n",
+    "\n",
+    "for fname in os.listdir(datafolder):\n",
+    "    filename = datafolder + fname\n",
+    "    [eeg_data, states_labels, sampling_rate, chan_names, chan_numb, samp_numb, states_codes] = open_eeg_mat(filename, centered=False)\n",
+    "    sampling_rate = sampling_rate[0,0]\n",
+    "    eeg_data = butter_bandpass_filter(eeg_data, 0.5, 45, sampling_rate, order=5, how_to_filt = 'simultaneously')\n",
+    "    \n",
+    "    states_labels = states_labels[0]\n",
+    "    print(states_labels)\n",
+    "    states_labels = states_labels[2000:-2000]\n",
+    "    eeg_data = eeg_data[:,2000:-2000]\n",
+    "    \n",
+    "    experiment_name = \"_\".join(fname.split(\"_\")[:-1])\n",
+    "    if fname.endswith(\"_2.mat\"):\n",
+    "        test_datas[experiment_name] = {\"eeg_data\": eeg_data.T, \"labels\": states_labels}\n",
+    "    elif fname.endswith(\"_1.mat\"):\n",
+    "        train_datas[experiment_name] = {\"eeg_data\": eeg_data.T, \"labels\": states_labels}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# separate scaling for each user, should not hurt \n",
+    "from sklearn.preprocessing import StandardScaler\n",
+    "\n",
+    "for key in train_datas.keys():\n",
+    "    sc = StandardScaler()\n",
+    "    train_datas[key][\"eeg_data\"] = sc.fit_transform(train_datas[key][\"eeg_data\"])\n",
+    "    test_datas[key][\"eeg_data\"] = sc.fit_transform(test_datas[key][\"eeg_data\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "slice_len = 500"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "def generate_slice(test=False):\n",
+    "    if test:\n",
+    "        experiment_data = random.choice(list(test_datas.values()))\n",
+    "    else:\n",
+    "        experiment_data = random.choice(list(train_datas.values()))\n",
+    "    \n",
+    "    X = experiment_data[\"eeg_data\"]\n",
+    "    y = experiment_data[\"labels\"]\n",
+    "    \n",
+    "    while True:\n",
+    "        slice_start = np.random.choice(len(X) - slice_len)\n",
+    "        slice_end = slice_start + slice_len\n",
+    "        slice_x = X[slice_start:slice_end]\n",
+    "        #slice_x = normalize(slice_x)\n",
+    "        slice_y = y[slice_start:slice_end]\n",
+    "        \n",
+    "        if len(set(slice_y)) == 1:\n",
+    "            return slice_x, to_onehot(slice_y[-1])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(500, 24)"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "generate_slice()[0].shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "def data_generator(batch_size, test=False):\n",
+    "    while True:\n",
+    "        batch_x = []\n",
+    "        batch_y = []\n",
+    "        \n",
+    "        for i in range(0, batch_size):\n",
+    "            x, y = generate_slice(test=test)\n",
+    "            batch_x.append(x)\n",
+    "            batch_y.append(y)\n",
+    "            \n",
+    "        y = np.array(batch_y)\n",
+    "        x = np.array([i for i in batch_x])\n",
+    "        yield (x, y)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Using TensorFlow backend.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from keras.layers import Convolution1D, Dense, Dropout, Input, merge, GlobalMaxPooling1D, MaxPooling1D, Flatten, LSTM\n",
+    "from keras.models import Model, load_model\n",
+    "from keras.optimizers import RMSprop"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "def get_base_model(input_len, fsize):\n",
+    "    '''Base network to be shared (eq. to feature extraction).\n",
+    "    '''\n",
+    "    input_seq = Input(shape=(input_len, 24))\n",
+    "    nb_filters = 50\n",
+    "    convolved = Convolution1D(nb_filters, 5, border_mode=\"same\", activation=\"tanh\")(input_seq)\n",
+    "    pooled = GlobalMaxPooling1D()(convolved)\n",
+    "    compressed = Dense(50, activation=\"linear\")(pooled)\n",
+    "    compressed = Dropout(0.3)(compressed)\n",
+    "    compressed = Dense(50, activation=\"relu\")(compressed)\n",
+    "    compressed = Dropout(0.3)(compressed)\n",
+    "    model = Model(input=input_seq, output=compressed)            \n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "input1125_seq = Input(shape=(slice_len, 24))\n",
+    "\n",
+    "base_network1125 = get_base_model(slice_len, 10)\n",
+    "\n",
+    "embedding_1125 = base_network1125(input1125_seq)\n",
+    "out = Dense(3, activation='softmax')(embedding_1125)\n",
+    "    \n",
+    "model = Model(input=input1125_seq, output=out)\n",
+    "    \n",
+    "model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"categorical_accuracy\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 1/1\n",
+      "47s - loss: 0.8520 - categorical_accuracy: 0.5665 - val_loss: 0.7560 - val_categorical_accuracy: 0.6100\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<keras.callbacks.History at 0x7f1f07163668>"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from keras.callbacks import EarlyStopping, ModelCheckpoint\n",
+    "\n",
+    "nb_epoch = 100000\n",
+    "earlyStopping = EarlyStopping(monitor='categorical_accuracy', patience=10, verbose=0, mode='auto')\n",
+    "checkpointer = ModelCheckpoint(\"convlstm_alldata.h5\", monitor='categorical_accuracy', verbose=0,\n",
+    "                               save_best_only=True, mode='auto', period=1)\n",
+    "\n",
+    "samples_per_epoch = 15000\n",
+    "nb_epoch = 1\n",
+    "\n",
+    "model.fit_generator(data_generator(batch_size=25), samples_per_epoch, nb_epoch, \n",
+    "                    callbacks=[earlyStopping, checkpointer], verbose=2, nb_val_samples=15000,\n",
+    "                    validation_data=data_generator(batch_size=25, test=True))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.5.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}