NewDatasetConvnet.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from scipy import io\n",
    "from scipy.signal import butter, lfilter\n",
    "import h5py\n",
    "import random\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "datafolder = \"new_dataset/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Band-pass filtering helpers adapted from the provided notebook\n",
    "\n",
    "def butter_bandpass(lowcut, highcut, sampling_rate, order=5):\n",
    "    # Butterworth band-pass design; cutoffs are normalized by the Nyquist frequency\n",
    "    nyq_freq = sampling_rate*0.5\n",
    "    low = lowcut/nyq_freq\n",
    "    high = highcut/nyq_freq\n",
    "    b, a = butter(order, [low, high], btype='band')\n",
    "    return b, a\n",
    "\n",
    "def butter_high_low_pass(lowcut, highcut, sampling_rate, order=5):\n",
    "    # Separate high-pass and low-pass designs that together act as a band-pass\n",
    "    nyq_freq = sampling_rate*0.5\n",
    "    lower_bound = lowcut/nyq_freq\n",
    "    higher_bound = highcut/nyq_freq\n",
    "    b_high, a_high = butter(order, lower_bound, btype='high')\n",
    "    b_low, a_low = butter(order, higher_bound, btype='low')\n",
    "    return b_high, a_high, b_low, a_low\n",
    "\n",
    "def butter_bandpass_filter(data, lowcut, highcut, sampling_rate, order=5, how_to_filt='separately'):\n",
    "    if how_to_filt == 'separately':\n",
    "        b_high, a_high, b_low, a_low = butter_high_low_pass(lowcut, highcut, sampling_rate, order=order)\n",
    "        y = lfilter(b_high, a_high, data)\n",
    "        y = lfilter(b_low, a_low, y)\n",
    "    elif how_to_filt == 'simultaneously':\n",
    "        b, a = butter_bandpass(lowcut, highcut, sampling_rate, order=order)\n",
    "        y = lfilter(b, a, data)\n",
    "    return y"
   ]
  },
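  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the filter helpers on synthetic data (my own sketch, not part of the original pipeline; the 500 Hz rate is assumed purely for illustration): a tone inside the 0.5-45 Hz band should pass, one outside should be strongly attenuated.\n",
    "\n",
    "```python\n",
    "fs = 500                                 # assumed sampling rate (Hz)\n",
    "t = np.arange(0, 10, 1 / fs)\n",
    "in_band = np.sin(2 * np.pi * 10 * t)     # 10 Hz, inside the band\n",
    "out_band = np.sin(2 * np.pi * 90 * t)    # 90 Hz, outside the band\n",
    "filt_in = butter_bandpass_filter(in_band, 0.5, 45, fs, how_to_filt='simultaneously')\n",
    "filt_out = butter_bandpass_filter(out_band, 0.5, 45, fs, how_to_filt='simultaneously')\n",
    "# Compare amplitudes after the filter transient has settled\n",
    "print(np.std(filt_in[fs:]), np.std(filt_out[fs:]))   # expect ~0.7 vs ~0\n",
    "```"
   ]
  },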
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def open_eeg_mat(filename, centered=True):\n",
    "    # Load one recording: EEG matrix (channels x samples), per-sample state\n",
    "    # labels, sampling rate and channel names, as stored in the .mat file\n",
    "    all_data = io.loadmat(filename)\n",
    "    eeg_data = all_data['data_cur']\n",
    "    if centered:\n",
    "        eeg_data = eeg_data - np.mean(eeg_data,1)[np.newaxis].T\n",
    "        print('Data were centered: channels are zero-mean')\n",
    "    states_labels = all_data['states_cur']\n",
    "    states_codes = list(np.unique(states_labels)[:])\n",
    "    sampling_rate = all_data['srate']\n",
    "    chan_names = all_data['chan_names']\n",
    "    return eeg_data, states_labels, sampling_rate, chan_names, eeg_data.shape[0], eeg_data.shape[1], states_codes"
   ]
  },
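  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a minimal way to inspect what one of the .mat files actually contains (the key names `data_cur`, `states_cur`, `srate` and `chan_names` are the ones `open_eeg_mat` relies on):\n",
    "\n",
    "```python\n",
    "sample = io.loadmat(datafolder + os.listdir(datafolder)[0])\n",
    "for key, value in sample.items():\n",
    "    if not key.startswith('__'):          # skip MATLAB header entries\n",
    "        print(key, getattr(value, 'shape', type(value)))\n",
    "```"
   ]
  },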
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[6 6 6 ..., 6 6 6]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[6 6 6 ..., 6 6 6]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[6 6 6 ..., 6 6 6]\n",
      "[1 1 1 ..., 6 6 6]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[1 1 1 ..., 2 2 2]\n",
      "[6 6 6 ..., 2 2 2]\n"
     ]
    }
   ],
   "source": [
    "train_datas = {}\n",
    "test_datas = {}\n",
    "\n",
    "def to_onehot(label):\n",
    "    # State codes 1, 2 and 6 map to the three output classes\n",
    "    labels_encoding = {1: np.array([1,0,0]), 2: np.array([0,1,0]), 6: np.array([0,0,1])}\n",
    "    return labels_encoding[label]\n",
    "\n",
    "for fname in os.listdir(datafolder):\n",
    "    filename = datafolder + fname\n",
    "    eeg_data, states_labels, sampling_rate, chan_names, chan_numb, samp_numb, states_codes = open_eeg_mat(filename, centered=False)\n",
    "    sampling_rate = sampling_rate[0,0]\n",
    "    eeg_data = butter_bandpass_filter(eeg_data, 0.5, 45, sampling_rate, order=5, how_to_filt='simultaneously')\n",
    "\n",
    "    states_labels = states_labels[0]\n",
    "    print(states_labels)\n",
    "    # Trim 2000 samples from each end to discard edge transients\n",
    "    states_labels = states_labels[2000:-2000]\n",
    "    eeg_data = eeg_data[:,2000:-2000]\n",
    "\n",
    "    # Recordings ending in _1.mat go to the train set, _2.mat to the test set\n",
    "    experiment_name = \"_\".join(fname.split(\"_\")[:-1])\n",
    "    if fname.endswith(\"_2.mat\"):\n",
    "        test_datas[experiment_name] = {\"eeg_data\": eeg_data.T, \"labels\": states_labels}\n",
    "    elif fname.endswith(\"_1.mat\"):\n",
    "        train_datas[experiment_name] = {\"eeg_data\": eeg_data.T, \"labels\": states_labels}"
   ]
  },
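  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick diagnostic (my addition, not in the original pipeline) to see how the three states are balanced within each training recording:\n",
    "\n",
    "```python\n",
    "for name, data in train_datas.items():\n",
    "    labels, counts = np.unique(data['labels'], return_counts=True)\n",
    "    print(name, dict(zip(labels, counts)))\n",
    "```"
   ]
  },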
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Separate scaling for each user; the scaler is fit on the training\n",
    "# recording only and reused on the test recording, so no test statistics\n",
    "# leak into preprocessing.\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "for key in train_datas.keys():\n",
    "    sc = StandardScaler()\n",
    "    train_datas[key][\"eeg_data\"] = sc.fit_transform(train_datas[key][\"eeg_data\"])\n",
    "    test_datas[key][\"eeg_data\"] = sc.transform(test_datas[key][\"eeg_data\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "slice_len = 500  # length of each training window, in samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generate_slice(test=False):\n",
    "    # Draw a random experiment, then rejection-sample a window until it\n",
    "    # contains exactly one state label\n",
    "    if test:\n",
    "        experiment_data = random.choice(list(test_datas.values()))\n",
    "    else:\n",
    "        experiment_data = random.choice(list(train_datas.values()))\n",
    "\n",
    "    X = experiment_data[\"eeg_data\"]\n",
    "    y = experiment_data[\"labels\"]\n",
    "\n",
    "    while True:\n",
    "        slice_start = np.random.choice(len(X) - slice_len)\n",
    "        slice_end = slice_start + slice_len\n",
    "        slice_x = X[slice_start:slice_end]\n",
    "        slice_y = y[slice_start:slice_end]\n",
    "\n",
    "        if len(set(slice_y)) == 1:\n",
    "            return slice_x, to_onehot(slice_y[-1])"
   ]
  },
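  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Windows that straddle a state change are rejected and redrawn, so the loop above terminates quickly as long as most windows are single-state. A rough estimate of the rejection rate (illustrative sketch, not in the original notebook):\n",
    "\n",
    "```python\n",
    "data = random.choice(list(train_datas.values()))\n",
    "trials = 1000\n",
    "rejected = sum(\n",
    "    len(set(data['labels'][s:s + slice_len])) > 1\n",
    "    for s in np.random.choice(len(data['eeg_data']) - slice_len, trials)\n",
    ")\n",
    "print('rejected fraction: %.3f' % (rejected / trials))\n",
    "```"
   ]
  },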
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(500, 24)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generate_slice()[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def data_generator(batch_size, test=False):\n",
    "    # Endless generator of (batch_x, batch_y) pairs for fit_generator\n",
    "    while True:\n",
    "        batch_x = []\n",
    "        batch_y = []\n",
    "\n",
    "        for i in range(0, batch_size):\n",
    "            x, y = generate_slice(test=test)\n",
    "            batch_x.append(x)\n",
    "            batch_y.append(y)\n",
    "\n",
    "        yield (np.array(batch_x), np.array(batch_y))"
   ]
  },
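  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal check of what the generator yields (the variable names here are illustrative):\n",
    "\n",
    "```python\n",
    "xb, yb = next(data_generator(batch_size=25))\n",
    "print(xb.shape, yb.shape)   # expect (25, 500, 24) and (25, 3)\n",
    "```"
   ]
  },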
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.layers import Convolution1D, Dense, Dropout, Input, merge, GlobalMaxPooling1D, MaxPooling1D, Flatten, LSTM\n",
    "from keras.models import Model, load_model\n",
    "from keras.optimizers import RMSprop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_base_model(input_len, fsize):\n",
    "    '''Base network to be shared (equivalent to feature extraction).'''\n",
    "    input_seq = Input(shape=(input_len, 24))\n",
    "    nb_filters = 50\n",
    "    # fsize is the length of the 1-D convolution filters along the time axis\n",
    "    convolved = Convolution1D(nb_filters, fsize, border_mode=\"same\", activation=\"tanh\")(input_seq)\n",
    "    pooled = GlobalMaxPooling1D()(convolved)\n",
    "    compressed = Dense(50, activation=\"linear\")(pooled)\n",
    "    compressed = Dropout(0.3)(compressed)\n",
    "    compressed = Dense(50, activation=\"relu\")(compressed)\n",
    "    compressed = Dropout(0.3)(compressed)\n",
    "    model = Model(input=input_seq, output=compressed)\n",
    "    return model"
   ]
  },
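  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The feature extractor is a single 1-D convolution over time followed by global max pooling, so the classifier sees one 50-dimensional summary per window regardless of window length. A back-of-the-envelope parameter count for the full model with `fsize=10`:\n",
    "\n",
    "```python\n",
    "conv = (10 * 24 + 1) * 50    # fsize * channels + bias, per filter -> 12,050\n",
    "dense1 = (50 + 1) * 50       # 2,550\n",
    "dense2 = (50 + 1) * 50       # 2,550\n",
    "softmax = (50 + 1) * 3       # 153\n",
    "print(conv + dense1 + dense2 + softmax)   # 17,303\n",
    "```"
   ]
  },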
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "input1125_seq = Input(shape=(slice_len, 24))\n",
    "\n",
    "base_network1125 = get_base_model(slice_len, 10)\n",
    "\n",
    "embedding_1125 = base_network1125(input1125_seq)\n",
    "out = Dense(3, activation='softmax')(embedding_1125)\n",
    "\n",
    "model = Model(input=input1125_seq, output=out)\n",
    "\n",
    "model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"categorical_accuracy\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/1\n",
      "47s - loss: 0.8520 - categorical_accuracy: 0.5665 - val_loss: 0.7560 - val_categorical_accuracy: 0.6100\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f1f07163668>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from keras.callbacks import EarlyStopping, ModelCheckpoint\n",
    "\n",
    "# Both callbacks monitor the training accuracy; switch the monitor to\n",
    "# 'val_categorical_accuracy' to track the held-out score instead.\n",
    "earlyStopping = EarlyStopping(monitor='categorical_accuracy', patience=10, verbose=0, mode='auto')\n",
    "checkpointer = ModelCheckpoint(\"convlstm_alldata.h5\", monitor='categorical_accuracy', verbose=0,\n",
    "                               save_best_only=True, mode='auto', period=1)\n",
    "\n",
    "samples_per_epoch = 15000\n",
    "nb_epoch = 1\n",
    "\n",
    "model.fit_generator(data_generator(batch_size=25), samples_per_epoch, nb_epoch,\n",
    "                    callbacks=[earlyStopping, checkpointer], verbose=2, nb_val_samples=15000,\n",
    "                    validation_data=data_generator(batch_size=25, test=True))"
   ]
  },
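  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: this notebook uses the Keras 1.x API (`border_mode`, `Model(input=..., output=...)`, `samples_per_epoch`, `nb_val_samples`). Under Keras 2 the training call would look roughly like the untested sketch below; the step arguments count batches rather than samples.\n",
    "\n",
    "```python\n",
    "model.fit_generator(data_generator(batch_size=25),\n",
    "                    steps_per_epoch=15000 // 25,\n",
    "                    epochs=1,\n",
    "                    callbacks=[earlyStopping, checkpointer], verbose=2,\n",
    "                    validation_data=data_generator(batch_size=25, test=True),\n",
    "                    validation_steps=15000 // 25)\n",
    "```"
   ]
  }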
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}