Switch to unified view

a b/.ipynb_checkpoints/EEGNet-PyTorch-checkpoint.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": 3,
6
   "metadata": {
7
    "collapsed": false
8
   },
9
   "outputs": [],
10
   "source": [
11
    "\"\"\"\n",
12
    "Written by, \n",
13
    "Sriram Ravindran, sriram@ucsd.edu\n",
14
    "\n",
15
    "Original paper - https://arxiv.org/abs/1611.08024\n",
16
    "\n",
17
    "Please reach out to me if you spot an error.\n",
18
    "\"\"\""
19
   ]
20
  },
21
  {
22
   "cell_type": "code",
23
   "execution_count": 4,
24
   "metadata": {
25
    "collapsed": true
26
   },
27
   "outputs": [],
28
   "source": [
29
    "import numpy as np\n",
30
    "from sklearn.metrics import roc_auc_score, precision_score, recall_score, accuracy_score\n",
31
    "import torch\n",
32
    "import torch.nn as nn\n",
33
    "import torch.optim as optim\n",
34
    "from torch.autograd import Variable\n",
35
    "import torch.nn.functional as F\n",
36
    "import torch.optim as optim"
37
   ]
38
  },
39
  {
40
   "cell_type": "markdown",
41
   "metadata": {},
42
   "source": [
43
    "<p>Here's the description from the paper</p>\n",
44
    "<img src=\"EEGNet.png\" style=\"width: 700px; float:left;\">"
45
   ]
46
  },
47
  {
48
   "cell_type": "code",
49
   "execution_count": 12,
50
   "metadata": {
51
    "collapsed": false
52
   },
53
   "outputs": [
54
    {
55
     "name": "stdout",
56
     "output_type": "stream",
57
     "text": [
58
      "Variable containing:\n",
59
      " 0.7338\n",
60
      "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n",
61
      "\n"
62
     ]
63
    }
64
   ],
65
   "source": [
66
    "class EEGNet(nn.Module):\n",
67
    "    def __init__(self):\n",
68
    "        super(EEGNet, self).__init__()\n",
69
    "        self.T = 120\n",
70
    "        \n",
71
    "        # Layer 1\n",
72
    "        self.conv1 = nn.Conv2d(1, 16, (1, 64), padding = 0)\n",
73
    "        self.batchnorm1 = nn.BatchNorm2d(16, False)\n",
74
    "        \n",
75
    "        # Layer 2\n",
76
    "        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))\n",
77
    "        self.conv2 = nn.Conv2d(1, 4, (2, 32))\n",
78
    "        self.batchnorm2 = nn.BatchNorm2d(4, False)\n",
79
    "        self.pooling2 = nn.MaxPool2d(2, 4)\n",
80
    "        \n",
81
    "        # Layer 3\n",
82
    "        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))\n",
83
    "        self.conv3 = nn.Conv2d(4, 4, (8, 4))\n",
84
    "        self.batchnorm3 = nn.BatchNorm2d(4, False)\n",
85
    "        self.pooling3 = nn.MaxPool2d((2, 4))\n",
86
    "        \n",
87
    "        # FC Layer\n",
88
    "        # NOTE: This dimension will depend on the number of timestamps per sample in your data.\n",
89
    "        # I have 120 timepoints. \n",
90
    "        self.fc1 = nn.Linear(4*2*7, 1)\n",
91
    "        \n",
92
    "\n",
93
    "    def forward(self, x):\n",
94
    "        # Layer 1\n",
95
    "        x = F.elu(self.conv1(x))\n",
96
    "        x = self.batchnorm1(x)\n",
97
    "        x = F.dropout(x, 0.25)\n",
98
    "        x = x.permute(0, 3, 1, 2)\n",
99
    "        \n",
100
    "        # Layer 2\n",
101
    "        x = self.padding1(x)\n",
102
    "        x = F.elu(self.conv2(x))\n",
103
    "        x = self.batchnorm2(x)\n",
104
    "        x = F.dropout(x, 0.25)\n",
105
    "        x = self.pooling2(x)\n",
106
    "        \n",
107
    "        # Layer 3\n",
108
    "        x = self.padding2(x)\n",
109
    "        x = F.elu(self.conv3(x))\n",
110
    "        x = self.batchnorm3(x)\n",
111
    "        x = F.dropout(x, 0.25)\n",
112
    "        x = self.pooling3(x)\n",
113
    "        \n",
114
    "        # FC Layer\n",
115
    "        x = x.view(-1, 4*2*7)\n",
116
    "        x = F.sigmoid(self.fc1(x))\n",
117
    "        return x\n",
118
    "\n",
119
    "\n",
120
    "net = EEGNet().cuda(0)\n",
121
    "print net.forward(Variable(torch.Tensor(np.random.rand(1, 1, 120, 64)).cuda(0)))\n",
122
    "criterion = nn.BCELoss()\n",
123
    "optimizer = optim.Adam(net.parameters())"
124
   ]
125
  },
126
  {
127
   "cell_type": "markdown",
128
   "metadata": {},
129
   "source": [
130
    "#### Evaluate function returns values of different criteria like accuracy, precision etc. \n",
131
    "In case you face memory overflow issues, use batch size to control how many samples get evaluated at one time. Use a batch_size that is a factor of length of samples. This ensures that you won't miss any samples."
132
   ]
133
  },
134
  {
135
   "cell_type": "code",
136
   "execution_count": 13,
137
   "metadata": {
138
    "collapsed": true
139
   },
140
   "outputs": [],
141
   "source": [
142
    "def evaluate(model, X, Y, params = [\"acc\"]):\n",
143
    "    results = []\n",
144
    "    batch_size = 100\n",
145
    "    \n",
146
    "    predicted = []\n",
147
    "    \n",
148
    "    for i in range(len(X)/batch_size):\n",
149
    "        s = i*batch_size\n",
150
    "        e = i*batch_size+batch_size\n",
151
    "        \n",
152
    "        inputs = Variable(torch.from_numpy(X[s:e]).cuda(0))\n",
153
    "        pred = model(inputs)\n",
154
    "        \n",
155
    "        predicted.append(pred.data.cpu().numpy())\n",
156
    "        \n",
157
    "        \n",
158
    "    inputs = Variable(torch.from_numpy(X).cuda(0))\n",
159
    "    predicted = model(inputs)\n",
160
    "    \n",
161
    "    predicted = predicted.data.cpu().numpy()\n",
162
    "    \n",
163
    "    for param in params:\n",
164
    "        if param == 'acc':\n",
165
    "            results.append(accuracy_score(Y, np.round(predicted)))\n",
166
    "        if param == \"auc\":\n",
167
    "            results.append(roc_auc_score(Y, predicted))\n",
168
    "        if param == \"recall\":\n",
169
    "            results.append(recall_score(Y, np.round(predicted)))\n",
170
    "        if param == \"precision\":\n",
171
    "            results.append(precision_score(Y, np.round(predicted)))\n",
172
    "        if param == \"fmeasure\":\n",
173
    "            precision = precision_score(Y, np.round(predicted))\n",
174
    "            recall = recall_score(Y, np.round(predicted))\n",
175
    "            results.append(2*precision*recall/ (precision+recall))\n",
176
    "    return results"
177
   ]
178
  },
179
  {
180
   "cell_type": "markdown",
181
   "metadata": {},
182
   "source": [
183
    "#### Generate random data\n",
184
    "\n",
185
    "##### Data format:\n",
186
    "Datatype - float32 (both X and Y) <br>\n",
187
    "X.shape - (#samples, 1, #timepoints,  #channels) <br>\n",
188
    "Y.shape - (#samples)"
189
   ]
190
  },
191
  {
192
   "cell_type": "code",
193
   "execution_count": 14,
194
   "metadata": {
195
    "collapsed": true
196
   },
197
   "outputs": [],
198
   "source": [
199
    "X_train = np.random.rand(100, 1, 120, 64).astype('float32') # np.random.rand generates between [0, 1)\n",
200
    "y_train = np.round(np.random.rand(100).astype('float32')) # binary data, so we round it to 0 or 1.\n",
201
    "\n",
202
    "X_val = np.random.rand(100, 1, 120, 64).astype('float32')\n",
203
    "y_val = np.round(np.random.rand(100).astype('float32'))\n",
204
    "\n",
205
    "X_test = np.random.rand(100, 1, 120, 64).astype('float32')\n",
206
    "y_test = np.round(np.random.rand(100).astype('float32'))"
207
   ]
208
  },
209
  {
210
   "cell_type": "markdown",
211
   "metadata": {},
212
   "source": [
213
    "#### Run"
214
   ]
215
  },
216
  {
217
   "cell_type": "code",
218
   "execution_count": 15,
219
   "metadata": {
220
    "collapsed": false
221
   },
222
   "outputs": [
223
    {
224
     "name": "stdout",
225
     "output_type": "stream",
226
     "text": [
227
      "\n",
228
      "Epoch  0\n",
229
      "['acc', 'auc', 'fmeasure']\n",
230
      "Training Loss  1.54113572836\n",
231
      "Train -  [0.54000000000000004, 0.59178743961352653, 0.70129870129870131]\n",
232
      "Validation -  [0.51000000000000001, 0.48539415766306526, 0.67549668874172186]\n",
233
      "Test -  [0.5, 0.50319999999999998, 0.66666666666666663]\n",
234
      "\n",
235
      "Epoch  1\n",
236
      "['acc', 'auc', 'fmeasure']\n",
237
      "Training Loss  1.42391115427\n",
238
      "Train -  [0.54000000000000004, 0.63888888888888895, 0.70129870129870131]\n",
239
      "Validation -  [0.51000000000000001, 0.47458983593437376, 0.67549668874172186]\n",
240
      "Test -  [0.5, 0.50439999999999996, 0.66666666666666663]\n",
241
      "\n",
242
      "Epoch  2\n",
243
      "['acc', 'auc', 'fmeasure']\n",
244
      "Training Loss  1.3422973156\n",
245
      "Train -  [0.55000000000000004, 0.67995169082125606, 0.70198675496688734]\n",
246
      "Validation -  [0.53000000000000003, 0.46898759503801518, 0.68456375838926176]\n",
247
      "Test -  [0.51000000000000001, 0.50800000000000001, 0.67114093959731547]\n",
248
      "\n",
249
      "Epoch  3\n",
250
      "['acc', 'auc', 'fmeasure']\n",
251
      "Training Loss  1.28801095486\n",
252
      "Train -  [0.63, 0.71054750402576483, 0.73758865248226957]\n",
253
      "Validation -  [0.48999999999999999, 0.4601840736294518, 0.63309352517985618]\n",
254
      "Test -  [0.52000000000000002, 0.51000000000000001, 0.66666666666666663]\n",
255
      "\n",
256
      "Epoch  4\n",
257
      "['acc', 'auc', 'fmeasure']\n",
258
      "Training Loss  1.25420039892\n",
259
      "Train -  [0.68999999999999995, 0.74476650563607083, 0.75590551181102361]\n",
260
      "Validation -  [0.42999999999999999, 0.44457783113245297, 0.53658536585365846]\n",
261
      "Test -  [0.51000000000000001, 0.51000000000000001, 0.6080000000000001]\n",
262
      "\n",
263
      "Epoch  5\n",
264
      "['acc', 'auc', 'fmeasure']\n",
265
      "Training Loss  1.22989922762\n",
266
      "Train -  [0.75, 0.77375201288244766, 0.77876106194690276]\n",
267
      "Validation -  [0.46000000000000002, 0.43937575030012005, 0.49056603773584906]\n",
268
      "Test -  [0.51000000000000001, 0.50440000000000007, 0.55855855855855863]\n",
269
      "\n",
270
      "Epoch  6\n",
271
      "['acc', 'auc', 'fmeasure']\n",
272
      "Training Loss  1.20727479458\n",
273
      "Train -  [0.76000000000000001, 0.79227053140096615, 0.7735849056603773]\n",
274
      "Validation -  [0.40999999999999998, 0.43897559023609439, 0.40404040404040403]\n",
275
      "Test -  [0.47999999999999998, 0.496, 0.5]\n",
276
      "\n",
277
      "Epoch  7\n",
278
      "['acc', 'auc', 'fmeasure']\n",
279
      "Training Loss  1.18265104294\n",
280
      "Train -  [0.76000000000000001, 0.81119162640901765, 0.7735849056603773]\n",
281
      "Validation -  [0.40999999999999998, 0.43897559023609445, 0.40404040404040403]\n",
282
      "Test -  [0.44, 0.48999999999999999, 0.44]\n",
283
      "\n",
284
      "Epoch  8\n",
285
      "['acc', 'auc', 'fmeasure']\n",
286
      "Training Loss  1.15454357862\n",
287
      "Train -  [0.80000000000000004, 0.8248792270531401, 0.81132075471698106]\n",
288
      "Validation -  [0.40000000000000002, 0.43497398959583838, 0.40000000000000002]\n",
289
      "Test -  [0.48999999999999999, 0.48560000000000003, 0.51428571428571423]\n",
290
      "\n",
291
      "Epoch  9\n",
292
      "['acc', 'auc', 'fmeasure']\n",
293
      "Training Loss  1.12422537804\n",
294
      "Train -  [0.81000000000000005, 0.83816425120772953, 0.82568807339449546]\n",
295
      "Validation -  [0.40999999999999998, 0.43177270908363347, 0.41584158415841577]\n",
296
      "Test -  [0.47999999999999998, 0.4768, 0.52727272727272734]\n"
297
     ]
298
    }
299
   ],
300
   "source": [
301
    "batch_size = 32\n",
302
    "\n",
303
    "for epoch in range(10):  # loop over the dataset multiple times\n",
304
    "    print \"\\nEpoch \", epoch\n",
305
    "    \n",
306
    "    running_loss = 0.0\n",
307
    "    for i in range(len(X_train)/batch_size-1):\n",
308
    "        s = i*batch_size\n",
309
    "        e = i*batch_size+batch_size\n",
310
    "        \n",
311
    "        inputs = torch.from_numpy(X_train[s:e])\n",
312
    "        labels = torch.FloatTensor(np.array([y_train[s:e]]).T*1.0)\n",
313
    "        \n",
314
    "        # wrap them in Variable\n",
315
    "        inputs, labels = Variable(inputs.cuda(0)), Variable(labels.cuda(0))\n",
316
    "\n",
317
    "        # zero the parameter gradients\n",
318
    "        optimizer.zero_grad()\n",
319
    "\n",
320
    "        # forward + backward + optimize\n",
321
    "        outputs = net(inputs)\n",
322
    "        loss = criterion(outputs, labels)\n",
323
    "        loss.backward()\n",
324
    "        \n",
325
    "        \n",
326
    "        optimizer.step()\n",
327
    "        \n",
328
    "        running_loss += loss.data[0]\n",
329
    "    \n",
330
    "    # Validation accuracy\n",
331
    "    params = [\"acc\", \"auc\", \"fmeasure\"]\n",
332
    "    print params\n",
333
    "    print \"Training Loss \", running_loss\n",
334
    "    print \"Train - \", evaluate(net, X_train, y_train, params)\n",
335
    "    print \"Validation - \", evaluate(net, X_val, y_val, params)\n",
336
    "    print \"Test - \", evaluate(net, X_test, y_test, params)"
337
   ]
338
  },
339
  {
340
   "cell_type": "code",
341
   "execution_count": null,
342
   "metadata": {
343
    "collapsed": true
344
   },
345
   "outputs": [],
346
   "source": []
347
  }
348
 ],
349
 "metadata": {
350
  "kernelspec": {
351
   "display_name": "Python 2",
352
   "language": "python",
353
   "name": "python2"
354
  },
355
  "language_info": {
356
   "codemirror_mode": {
357
    "name": "ipython",
358
    "version": 2
359
   },
360
   "file_extension": ".py",
361
   "mimetype": "text/x-python",
362
   "name": "python",
363
   "nbconvert_exporter": "python",
364
   "pygments_lexer": "ipython2",
365
   "version": "2.7.6"
366
  }
367
 },
368
 "nbformat": 4,
369
 "nbformat_minor": 0
370
}