--- a +++ b/NewDatasetConvnet.ipynb @@ -0,0 +1,394 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from scipy import io\n", + "from scipy.signal import butter, lfilter\n", + "import h5py\n", + "import random\n", + "import numpy as np\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "datafolder = \"new_dataset/\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# some filtering code copypasted from provided notebook \n", + "\n", + "def butter_bandpass(lowcut, highcut, sampling_rate, order=5):\n", + " nyq_freq = sampling_rate*0.5\n", + " low = lowcut/nyq_freq\n", + " high = highcut/nyq_freq\n", + " b, a = butter(order, [low, high], btype='band')\n", + " return b, a\n", + "\n", + "def butter_high_low_pass(lowcut, highcut, sampling_rate, order=5):\n", + " nyq_freq = sampling_rate*0.5\n", + " lower_bound = lowcut/nyq_freq\n", + " higher_bound = highcut/nyq_freq\n", + " b_high, a_high = butter(order, lower_bound, btype='high')\n", + " b_low, a_low = butter(order, higher_bound, btype='low')\n", + " return b_high, a_high, b_low, a_low\n", + "\n", + "def butter_bandpass_filter(data, lowcut, highcut, sampling_rate, order=5, how_to_filt = 'separately'):\n", + " if how_to_filt == 'separately':\n", + " b_high, a_high, b_low, a_low = butter_high_low_pass(lowcut, highcut, sampling_rate, order=order)\n", + " y = lfilter(b_high, a_high, data)\n", + " y = lfilter(b_low, a_low, y)\n", + " elif how_to_filt == 'simultaneously':\n", + " b, a = butter_bandpass(lowcut, highcut, sampling_rate, order=order)\n", + " y = lfilter(b, a, data)\n", + " return y" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def open_eeg_mat(filename, centered=True):\n", + " all_data = io.loadmat(filename)\n", + " eeg_data = all_data['data_cur']\n", + " if centered:\n", + " eeg_data = eeg_data - np.mean(eeg_data,1)[np.newaxis].T\n", + " print('Data were centered: channels are zero-mean')\n", + " states_labels = all_data['states_cur']\n", + " states_codes = list(np.unique(states_labels)[:])\n", + " sampling_rate = all_data['srate']\n", + " chan_names = all_data['chan_names']\n", + " return eeg_data, states_labels, sampling_rate, chan_names, eeg_data.shape[0], eeg_data.shape[1], states_codes\n", + "\n", + "def butter_high_low_pass(lowcut, highcut, sampling_rate, order=5):\n", + " nyq_freq = sampling_rate*0.5\n", + " lower_bound = lowcut/nyq_freq\n", + " higher_bound = highcut/nyq_freq\n", + " b_high, a_high = butter(order, lower_bound, btype='high')\n", + " b_low, a_low = butter(order, higher_bound, btype='low')\n", + " return b_high, a_high, b_low, a_low\n", + "\n", + "def butter_bandpass_filter(data, lowcut, highcut, sampling_rate, order=5, how_to_filt = 'simultaneously'):\n", + " if how_to_filt == 'separately':\n", + " b_high, a_high, b_low, a_low = butter_high_low_pass(lowcut, highcut, sampling_rate, order=order)\n", + " y = lfilter(b_high, a_high, data)\n", + " y = lfilter(b_low, a_low, y)\n", + " elif how_to_filt == 'simultaneously':\n", + " b, a = butter_bandpass(lowcut, highcut, sampling_rate, order=order)\n", + " y = lfilter(b, a, data)\n", + " return y" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1 1 1 ..., 2 2 2]\n", + "[1 1 1 ..., 6 6 6]\n", + "[6 6 6 ..., 6 6 6]\n", + "[1 1 1 ..., 6 6 6]\n", + "[1 1 1 ..., 2 2 2]\n", + "[1 1 1 ..., 2 2 2]\n", + "[1 1 1 ..., 6 6 6]\n", + "[1 1 1 ..., 6 6 6]\n", + "[1 1 1 ..., 6 6 6]\n", + "[6 6 6 ..., 6 6 6]\n", + "[1 1 1 ..., 2 2 2]\n", + "[1 1 1 ..., 2 2 2]\n", + "[1 1 1 ..., 2 2 2]\n", + "[1 1 1 ..., 6 6 6]\n", + "[1 1 1 ..., 6 6 6]\n", + "[6 6 6 ..., 6 6 6]\n", + "[1 1 1 ..., 6 6 6]\n", + "[1 1 1 ..., 2 2 2]\n", + "[1 1 1 ..., 2 2 2]\n", + "[6 6 6 ..., 2 2 2]\n" + ] + } + ], + "source": [ + "train_datas = {}\n", + "test_datas = {}\n", + "\n", + "def to_onehot(label):\n", + " labels_encoding = {1: np.array([1,0,0]), 2: np.array([0,1,0]), 6: np.array([0,0,1])}\n", + " return labels_encoding[label]\n", + "\n", + "for fname in os.listdir(datafolder):\n", + " filename = datafolder + fname\n", + " [eeg_data, states_labels, sampling_rate, chan_names, chan_numb, samp_numb, states_codes] = open_eeg_mat(filename, centered=False)\n", + " sampling_rate = sampling_rate[0,0]\n", + " eeg_data = butter_bandpass_filter(eeg_data, 0.5, 45, sampling_rate, order=5, how_to_filt = 'simultaneously')\n", + " \n", + " states_labels = states_labels[0]\n", + " print(states_labels)\n", + " states_labels = states_labels[2000:-2000]\n", + " eeg_data = eeg_data[:,2000:-2000]\n", + " \n", + " experiment_name = \"_\".join(fname.split(\"_\")[:-1])\n", + " if fname.endswith(\"_2.mat\"):\n", + " test_datas[experiment_name] = {\"eeg_data\": eeg_data.T, \"labels\": states_labels}\n", + " elif fname.endswith(\"_1.mat\"):\n", + " train_datas[experiment_name] = {\"eeg_data\": eeg_data.T, \"labels\": states_labels}" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# separate scaling for each user, should not hurt \n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "for key in train_datas.keys():\n", + " sc = StandardScaler()\n", + " train_datas[key][\"eeg_data\"] = sc.fit_transform(train_datas[key][\"eeg_data\"])\n", + " test_datas[key][\"eeg_data\"] = sc.fit_transform(test_datas[key][\"eeg_data\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "slice_len = 500" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def generate_slice(test=False):\n", + " if test:\n", + " experiment_data = random.choice(list(test_datas.values()))\n", + " else:\n", + " experiment_data = random.choice(list(train_datas.values()))\n", + " \n", + " X = experiment_data[\"eeg_data\"]\n", + " y = experiment_data[\"labels\"]\n", + " \n", + " while True:\n", + " slice_start = np.random.choice(len(X) - slice_len)\n", + " slice_end = slice_start + slice_len\n", + " slice_x = X[slice_start:slice_end]\n", + " #slice_x = normalize(slice_x)\n", + " slice_y = y[slice_start:slice_end]\n", + " \n", + " if len(set(slice_y)) == 1:\n", + " return slice_x, to_onehot(slice_y[-1])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(500, 24)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "generate_slice()[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def data_generator(batch_size, test=False):\n", + " while True:\n", + " batch_x = []\n", + " batch_y = []\n", + " \n", + " for i in range(0, batch_size):\n", + " x, y = generate_slice(test=test)\n", + " batch_x.append(x)\n", + " batch_y.append(y)\n", + " \n", + " y = np.array(batch_y)\n", + " x = np.array([i for i in batch_x])\n", + " yield (x, y)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using TensorFlow backend.\n" + ] + } + ], + "source": [ + "from keras.layers import Convolution1D, Dense, Dropout, Input, merge, GlobalMaxPooling1D, MaxPooling1D, Flatten, LSTM\n", + "from keras.models import Model, load_model\n", + "from keras.optimizers import RMSprop" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def get_base_model(input_len, fsize):\n", + " '''Base network to be shared (eq. to feature extraction).\n", + " '''\n", + " input_seq = Input(shape=(input_len, 24))\n", + " nb_filters = 50\n", + " convolved = Convolution1D(nb_filters, 5, border_mode=\"same\", activation=\"tanh\")(input_seq)\n", + " pooled = GlobalMaxPooling1D()(convolved)\n", + " compressed = Dense(50, activation=\"linear\")(pooled)\n", + " compressed = Dropout(0.3)(compressed)\n", + " compressed = Dense(50, activation=\"relu\")(compressed)\n", + " compressed = Dropout(0.3)(compressed)\n", + " model = Model(input=input_seq, output=compressed) \n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "input1125_seq = Input(shape=(slice_len, 24))\n", + "\n", + "base_network1125 = get_base_model(slice_len, 10)\n", + "\n", + "embedding_1125 = base_network1125(input1125_seq)\n", + "out = Dense(3, activation='softmax')(embedding_1125)\n", + " \n", + "model = Model(input=input1125_seq, output=out)\n", + " \n", + "model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"categorical_accuracy\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/1\n", + "47s - loss: 0.8520 - categorical_accuracy: 0.5665 - val_loss: 0.7560 - val_categorical_accuracy: 0.6100\n" + ] + }, + { + "data": { + "text/plain": [ + "<keras.callbacks.History at 0x7f1f07163668>" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from keras.callbacks import EarlyStopping, ModelCheckpoint\n", + "\n", + "nb_epoch = 100000\n", + "earlyStopping = EarlyStopping(monitor='categorical_accuracy', patience=10, verbose=0, mode='auto')\n", + "checkpointer = ModelCheckpoint(\"convlstm_alldata.h5\", monitor='categorical_accuracy', verbose=0,\n", + " save_best_only=True, mode='auto', period=1)\n", + "\n", + "samples_per_epoch = 15000\n", + "nb_epoch = 1\n", + "\n", + "model.fit_generator(data_generator(batch_size=25), samples_per_epoch, nb_epoch, \n", + " callbacks=[earlyStopping, checkpointer], verbose=2, nb_val_samples=15000,\n", + " validation_data=data_generator(batch_size=25, test=True))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}