{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from scipy import io\n",
    "from scipy.signal import butter, lfilter\n",
    "import h5py\n",
    "import random\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "datafolder = \"new_dataset/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# some filtering code copypasted from provided notebook \n",
    "\n",
    "def butter_bandpass(lowcut, highcut, sampling_rate, order=5):\n",
    "    nyq_freq = sampling_rate*0.5\n",
    "    low = lowcut/nyq_freq\n",
    "    high = highcut/nyq_freq\n",
    "    b, a = butter(order, [low, high], btype='band')\n",
    "    return b, a\n",
    "\n",
    "def butter_high_low_pass(lowcut, highcut, sampling_rate, order=5):\n",
    "    nyq_freq = sampling_rate*0.5\n",
    "    lower_bound = lowcut/nyq_freq\n",
    "    higher_bound = highcut/nyq_freq\n",
    "    b_high, a_high = butter(order, lower_bound, btype='high')\n",
    "    b_low, a_low = butter(order, higher_bound, btype='low')\n",
    "    return b_high, a_high, b_low, a_low\n",
    "\n",
    "def butter_bandpass_filter(data, lowcut, highcut, sampling_rate, order=5, how_to_filt = 'separately'):\n",
    "    if how_to_filt == 'separately':\n",
    "        b_high, a_high, b_low, a_low = butter_high_low_pass(lowcut, highcut, sampling_rate, order=order)\n",
    "        y = lfilter(b_high, a_high, data)\n",
    "        y = lfilter(b_low, a_low, y)\n",
    "    elif how_to_filt == 'simultaneously':\n",
    "        b, a = butter_bandpass(lowcut, highcut, sampling_rate, order=order)\n",
    "        y = lfilter(b, a, data)\n",
    "    return y"
   ]
  },
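  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the filter helpers (an editor's sketch, not part of the original run): build a synthetic two-channel signal with a 10 Hz component inside the 0.5-45 Hz band and a 100 Hz component outside it, and confirm the band-pass attenuates the fast component. The 500 Hz sampling rate (`fs_demo`) is an assumption for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Editor's sketch: synthetic signal at an assumed 500 Hz sampling rate\n",
    "fs_demo = 500\n",
    "t = np.arange(0, 2, 1/fs_demo)\n",
    "slow = np.sin(2*np.pi*10*t)   # inside the pass band\n",
    "fast = np.sin(2*np.pi*100*t)  # outside the pass band\n",
    "demo = np.vstack([slow + fast, slow + fast])  # (channels, samples), like eeg_data\n",
    "filtered = butter_bandpass_filter(demo, 0.5, 45, fs_demo, order=5, how_to_filt='simultaneously')\n",
    "print(filtered.shape)                  # shape is preserved: (2, 1000)\n",
    "print(np.std(demo), np.std(filtered))  # variance drops as the 100 Hz component is removed"
   ]
  },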
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def open_eeg_mat(filename, centered=True):\n",
    "    all_data = io.loadmat(filename)\n",
    "    eeg_data = all_data['data_cur']\n",
    "    if centered:\n",
    "        eeg_data = eeg_data - np.mean(eeg_data,1)[np.newaxis].T\n",
    "        print('Data were centered: channels are zero-mean')\n",
    "    states_labels = all_data['states_cur']\n",
    "    states_codes = list(np.unique(states_labels)[:])\n",
    "    sampling_rate = all_data['srate']\n",
    "    chan_names = all_data['chan_names']\n",
    "    return eeg_data, states_labels, sampling_rate, chan_names, eeg_data.shape[0], eeg_data.shape[1], states_codes\n",
    "\n",
    "def butter_high_low_pass(lowcut, highcut, sampling_rate, order=5):\n",
    "    nyq_freq = sampling_rate*0.5\n",
    "    lower_bound = lowcut/nyq_freq\n",
    "    higher_bound = highcut/nyq_freq\n",
    "    b_high, a_high = butter(order, lower_bound, btype='high')\n",
    "    b_low, a_low = butter(order, higher_bound, btype='low')\n",
    "    return b_high, a_high, b_low, a_low\n",
    "\n",
    "def butter_bandpass_filter(data, lowcut, highcut, sampling_rate, order=5, how_to_filt = 'simultaneously'):\n",
    "    if how_to_filt == 'separately':\n",
    "        b_high, a_high, b_low, a_low = butter_high_low_pass(lowcut, highcut, sampling_rate, order=order)\n",
    "        y = lfilter(b_high, a_high, data)\n",
    "        y = lfilter(b_low, a_low, y)\n",
    "    elif how_to_filt == 'simultaneously':\n",
    "        b, a = butter_bandpass(lowcut, highcut, sampling_rate, order=order)\n",
    "        y = lfilter(b, a, data)\n",
    "    return y"
   ]
  },
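  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`open_eeg_mat` assumes four variables in each .mat file: `data_cur` (channels x samples), `states_cur` (1 x samples state vector), `srate` and `chan_names`. The round-trip below is an editor's sketch of that layout; `demo.mat` is a made-up filename and the cell was not part of the original run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Editor's sketch: write a tiny .mat file with the layout open_eeg_mat expects\n",
    "io.savemat('demo.mat', {'data_cur': np.random.randn(24, 100),\n",
    "                        'states_cur': np.ones((1, 100), dtype=int),\n",
    "                        'srate': 500,\n",
    "                        'chan_names': ['ch%d' % i for i in range(24)]})\n",
    "demo_eeg, demo_labels, demo_rate, _, n_chan, n_samp, codes = open_eeg_mat('demo.mat', centered=True)\n",
    "print(n_chan, n_samp, codes)  # 24 100 [1]\n",
    "os.remove('demo.mat')  # clean up the scratch file"
   ]
  },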
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[6 6 6 ..., 6 6 6]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[6 6 6 ..., 6 6 6]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[6 6 6 ..., 6 6 6]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[6 6 6 ..., 2 2 2]\n"
     ]
    }
   ],
   "source": [
    "train_datas = {}\n",
    "test_datas = {}\n",
    "\n",
    "def to_onehot(label):\n",
    "    labels_encoding = {1: np.array([1,0,0]), 2: np.array([0,1,0]), 6: np.array([0,0,1])}\n",
    "    return labels_encoding[label]\n",
    "\n",
    "for fname in os.listdir(datafolder):\n",
    "    filename = datafolder + fname\n",
    "    [eeg_data, states_labels, sampling_rate, chan_names, chan_numb, samp_numb, states_codes] = open_eeg_mat(filename, centered=False)\n",
    "    sampling_rate = sampling_rate[0,0]\n",
    "    eeg_data = butter_bandpass_filter(eeg_data, 0.5, 45, sampling_rate, order=5, how_to_filt = 'simultaneously')\n",
    "    \n",
    "    states_labels = states_labels[0]\n",
    "    print(states_labels)\n",
    "    states_labels = states_labels[2000:-2000]\n",
    "    eeg_data = eeg_data[:,2000:-2000]\n",
    "    \n",
    "    experiment_name = \"_\".join(fname.split(\"_\")[:-1])\n",
    "    if fname.endswith(\"_2.mat\"):\n",
    "        test_datas[experiment_name] = {\"eeg_data\": eeg_data.T, \"labels\": states_labels}\n",
    "    elif fname.endswith(\"_1.mat\"):\n",
    "        train_datas[experiment_name] = {\"eeg_data\": eeg_data.T, \"labels\": states_labels}"
   ]
  },
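  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before sampling slices it is worth checking how the three state codes (1, 2, 6) are distributed in each training run, since a strong imbalance would bias the uniformly sampled slices. This check is an editor's addition, not part of the original run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Per-run label counts (editor's sanity check)\n",
    "for name in sorted(train_datas):\n",
    "    codes, counts = np.unique(train_datas[name]['labels'], return_counts=True)\n",
    "    print(name, dict(zip(codes, counts)))"
   ]
  },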
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# separate scaling for each user, should not hurt \n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "for key in train_datas.keys():\n",
    "    sc = StandardScaler()\n",
    "    train_datas[key][\"eeg_data\"] = sc.fit_transform(train_datas[key][\"eeg_data\"])\n",
    "    test_datas[key][\"eeg_data\"] = sc.fit_transform(test_datas[key][\"eeg_data\"])"
   ]
  },
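  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick check (editor's addition) that standardization behaved as expected: after `StandardScaler`, each channel of each training run should have mean close to 0 and standard deviation close to 1. Test runs reuse the train-fitted transform, so their statistics will be close but not exact."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Channels of each training run should now be ~zero-mean, unit-variance\n",
    "for key in train_datas:\n",
    "    assert np.allclose(train_datas[key]['eeg_data'].mean(axis=0), 0, atol=1e-6)\n",
    "    assert np.allclose(train_datas[key]['eeg_data'].std(axis=0), 1, atol=1e-6)\n",
    "print('train runs standardized OK')"
   ]
  },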
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "slice_len = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generate_slice(test=False):\n",
    "    if test:\n",
    "        experiment_data = random.choice(list(test_datas.values()))\n",
    "    else:\n",
    "        experiment_data = random.choice(list(train_datas.values()))\n",
    "    \n",
    "    X = experiment_data[\"eeg_data\"]\n",
    "    y = experiment_data[\"labels\"]\n",
    "    \n",
    "    while True:\n",
    "        slice_start = np.random.choice(len(X) - slice_len)\n",
    "        slice_end = slice_start + slice_len\n",
    "        slice_x = X[slice_start:slice_end]\n",
    "        #slice_x = normalize(slice_x)\n",
    "        slice_y = y[slice_start:slice_end]\n",
    "        \n",
    "        if len(set(slice_y)) == 1:\n",
    "            return slice_x, to_onehot(slice_y[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(500, 24)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generate_slice()[0].shape"
   ]
  },
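  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape check above confirms slices are (500, 24). It is also useful to see which classes `generate_slice` actually yields, since long homogeneous stretches of one state dominate the sample; the quick empirical count below (order [1, 2, 6], matching `to_onehot`) is an editor's addition."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Empirical class mix over 300 sampled slices (editor's sanity check)\n",
    "counts = np.zeros(3, dtype=int)\n",
    "for _ in range(300):\n",
    "    _, y = generate_slice()\n",
    "    counts += y\n",
    "print(counts / counts.sum())  # fraction of slices per class [1, 2, 6]"
   ]
  },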
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def data_generator(batch_size, test=False):\n",
    "    while True:\n",
    "        batch_x = []\n",
    "        batch_y = []\n",
    "        \n",
    "        for i in range(0, batch_size):\n",
    "            x, y = generate_slice(test=test)\n",
    "            batch_x.append(x)\n",
    "            batch_y.append(y)\n",
    "            \n",
    "        y = np.array(batch_y)\n",
    "        x = np.array([i for i in batch_x])\n",
    "        yield (x, y)"
   ]
  },
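  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pulling one batch from the generator (editor's addition) confirms the shapes Keras will see: `x` is (batch, 500, 24) and `y` is (batch, 3)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Draw a single batch to confirm the generator's output shapes\n",
    "bx, by = next(data_generator(batch_size=4))\n",
    "print(bx.shape, by.shape)  # (4, 500, 24) (4, 3)"
   ]
  },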
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.layers import Convolution1D, Dense, Dropout, Input, merge, GlobalMaxPooling1D, MaxPooling1D, Flatten, LSTM\n",
    "from keras.models import Model, load_model\n",
    "from keras.optimizers import RMSprop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_base_model(input_len, fsize):\n",
    "    '''Base network to be shared (eq. to feature extraction).\n",
    "    '''\n",
    "    input_seq = Input(shape=(input_len, 24))\n",
    "    nb_filters = 50\n",
    "    convolved = Convolution1D(nb_filters, 5, border_mode=\"same\", activation=\"tanh\")(input_seq)\n",
    "    pooled = GlobalMaxPooling1D()(convolved)\n",
    "    compressed = Dense(50, activation=\"linear\")(pooled)\n",
    "    compressed = Dropout(0.3)(compressed)\n",
    "    compressed = Dense(50, activation=\"relu\")(compressed)\n",
    "    compressed = Dropout(0.3)(compressed)\n",
    "    model = Model(input=input_seq, output=compressed)            \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "input1125_seq = Input(shape=(slice_len, 24))\n",
    "\n",
    "base_network1125 = get_base_model(slice_len, 10)\n",
    "\n",
    "embedding_1125 = base_network1125(input1125_seq)\n",
    "out = Dense(3, activation='softmax')(embedding_1125)\n",
    "    \n",
    "model = Model(input=input1125_seq, output=out)\n",
    "    \n",
    "model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"categorical_accuracy\"])"
   ]
  },
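  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`model.summary()` (editor's addition) prints the layer stack and parameter counts, a cheap way to confirm the convnet is wired as intended before training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model.summary()"
   ]
  },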
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/1\n",
      "47s - loss: 0.8520 - categorical_accuracy: 0.5665 - val_loss: 0.7560 - val_categorical_accuracy: 0.6100\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f1f07163668>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from keras.callbacks import EarlyStopping, ModelCheckpoint\n",
    "\n",
    "nb_epoch = 100000\n",
    "earlyStopping = EarlyStopping(monitor='categorical_accuracy', patience=10, verbose=0, mode='auto')\n",
    "checkpointer = ModelCheckpoint(\"convlstm_alldata.h5\", monitor='categorical_accuracy', verbose=0,\n",
    "                               save_best_only=True, mode='auto', period=1)\n",
    "\n",
    "samples_per_epoch = 15000\n",
    "nb_epoch = 1\n",
    "\n",
    "model.fit_generator(data_generator(batch_size=25), samples_per_epoch, nb_epoch, \n",
    "                    callbacks=[earlyStopping, checkpointer], verbose=2, nb_val_samples=15000,\n",
    "                    validation_data=data_generator(batch_size=25, test=True))"
   ]
  },
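  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training, one could reload the best checkpoint and score it on freshly sampled test slices. This is an editor's sketch using the Keras 1 `evaluate_generator` API (the filename comes from the `ModelCheckpoint` above); it was not part of the original run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Editor's sketch: reload the checkpointed model and evaluate on sampled test slices\n",
    "best_model = load_model(\"convlstm_alldata.h5\")\n",
    "loss, acc = best_model.evaluate_generator(data_generator(batch_size=25, test=True), val_samples=1500)\n",
    "print(\"test loss %.4f, test accuracy %.4f\" % (loss, acc))"
   ]
  }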
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}