b/.ipynb_checkpoints/EEGNet-PyTorch-checkpoint.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"\"\"\"\n",
"Written by\n",
"Sriram Ravindran, sriram@ucsd.edu\n",
"\n",
"Original paper - https://arxiv.org/abs/1611.08024\n",
"\n",
"Please reach out to me if you spot an error.\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import roc_auc_score, precision_score, recall_score, accuracy_score\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.autograd import Variable\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<p>Here's the description from the paper</p>\n",
"<img src=\"EEGNet.png\" style=\"width: 700px; float:left;\">"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable containing:\n",
" 0.7338\n",
"[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n",
"\n"
]
}
],
"source": [
"class EEGNet(nn.Module):\n",
"    def __init__(self):\n",
"        super(EEGNet, self).__init__()\n",
"        self.T = 120\n",
"\n",
"        # Layer 1\n",
"        self.conv1 = nn.Conv2d(1, 16, (1, 64), padding = 0)\n",
"        self.batchnorm1 = nn.BatchNorm2d(16, False)\n",
"\n",
"        # Layer 2\n",
"        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))\n",
"        self.conv2 = nn.Conv2d(1, 4, (2, 32))\n",
"        self.batchnorm2 = nn.BatchNorm2d(4, False)\n",
"        self.pooling2 = nn.MaxPool2d(2, 4)\n",
"\n",
"        # Layer 3\n",
"        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))\n",
"        self.conv3 = nn.Conv2d(4, 4, (8, 4))\n",
"        self.batchnorm3 = nn.BatchNorm2d(4, False)\n",
"        self.pooling3 = nn.MaxPool2d((2, 4))\n",
"\n",
"        # FC Layer\n",
"        # NOTE: This dimension will depend on the number of timepoints per sample in your data.\n",
"        # I have 120 timepoints, which gives a flattened size of 4*2*7 = 56 here (see the sketch after this cell).\n",
"        self.fc1 = nn.Linear(4*2*7, 1)\n",
"\n",
"\n",
"    def forward(self, x):\n",
"        # Layer 1\n",
"        x = F.elu(self.conv1(x))\n",
"        x = self.batchnorm1(x)\n",
"        x = F.dropout(x, 0.25)\n",
"        x = x.permute(0, 3, 1, 2)\n",
"\n",
"        # Layer 2\n",
"        x = self.padding1(x)\n",
"        x = F.elu(self.conv2(x))\n",
"        x = self.batchnorm2(x)\n",
"        x = F.dropout(x, 0.25)\n",
"        x = self.pooling2(x)\n",
"\n",
"        # Layer 3\n",
"        x = self.padding2(x)\n",
"        x = F.elu(self.conv3(x))\n",
"        x = self.batchnorm3(x)\n",
"        x = F.dropout(x, 0.25)\n",
"        x = self.pooling3(x)\n",
"\n",
"        # FC Layer\n",
"        x = x.view(-1, 4*2*7)\n",
"        x = F.sigmoid(self.fc1(x))\n",
"        return x\n",
"\n",
"\n",
"net = EEGNet().cuda(0)\n",
"print net.forward(Variable(torch.Tensor(np.random.rand(1, 1, 120, 64)).cuda(0)))\n",
"criterion = nn.BCELoss()\n",
"optimizer = optim.Adam(net.parameters())"
]
},
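{
"cell_type": "markdown",
"metadata": {},
"source": [
"If your samples have a different number of timepoints, the input dimension of fc1 changes. Below is a minimal sketch (not from the original notebook; the helper name `flattened_size` and the 64-channel default are assumptions): push a dummy tensor through the layers up to pooling3 and read off the flattened size."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Minimal sketch (assumption, not from the original notebook): infer the flattened\n",
"# size that feeds fc1 for T timepoints by tracing only the shape-changing layers.\n",
"def flattened_size(T, n_channels=64):\n",
"    m = EEGNet()                            # fc1 is sized for T=120, but we stop before it\n",
"    x = Variable(torch.zeros(1, 1, T, n_channels))\n",
"    x = m.conv1(x)                          # -> (1, 16, T, 1)\n",
"    x = x.permute(0, 3, 1, 2)               # -> (1, 1, 16, T)\n",
"    x = m.pooling2(m.conv2(m.padding1(x)))\n",
"    x = m.pooling3(m.conv3(m.padding2(x)))  # ELU/batchnorm/dropout don't change shapes\n",
"    return int(np.prod(x.size()[1:]))\n",
"\n",
"print flattened_size(120)  # 56 = 4*2*7, matching fc1 above"
]
},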
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### The evaluate function returns values for different criteria, like accuracy, precision, etc.\n",
"If you run into memory issues, use batch_size to control how many samples are evaluated at a time. Use a batch_size that divides the number of samples evenly, so that no samples are skipped (a short sketch after the function shows one way to handle a leftover partial batch)."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def evaluate(model, X, Y, params = [\"acc\"]):\n",
"    results = []\n",
"    batch_size = 100\n",
"\n",
"    predicted = []\n",
"\n",
"    for i in range(len(X)/batch_size):\n",
"        s = i*batch_size\n",
"        e = i*batch_size+batch_size\n",
"\n",
"        inputs = Variable(torch.from_numpy(X[s:e]).cuda(0))\n",
"        pred = model(inputs)\n",
"\n",
"        predicted.append(pred.data.cpu().numpy())\n",
"\n",
"    # Stack the per-batch predictions into a single (len(X), 1) array\n",
"    predicted = np.vstack(predicted)\n",
"\n",
"    for param in params:\n",
"        if param == 'acc':\n",
"            results.append(accuracy_score(Y, np.round(predicted)))\n",
"        if param == \"auc\":\n",
"            results.append(roc_auc_score(Y, predicted))\n",
"        if param == \"recall\":\n",
"            results.append(recall_score(Y, np.round(predicted)))\n",
"        if param == \"precision\":\n",
"            results.append(precision_score(Y, np.round(predicted)))\n",
"        if param == \"fmeasure\":\n",
"            precision = precision_score(Y, np.round(predicted))\n",
"            recall = recall_score(Y, np.round(predicted))\n",
"            results.append(2*precision*recall/ (precision+recall))\n",
"    return results"
]
},
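{
"cell_type": "markdown",
"metadata": {},
"source": [
"If batch_size does not divide the number of samples evenly, one simple option is to step through the data and let the final slice be smaller. A minimal sketch of that idea (the `predict_in_batches` helper is an assumption, not part of the original code):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Sketch (assumption): batched prediction that also covers a trailing partial batch,\n",
"# so batch_size no longer has to divide len(X) exactly.\n",
"def predict_in_batches(model, X, batch_size=100):\n",
"    preds = []\n",
"    for s in range(0, len(X), batch_size):\n",
"        inputs = Variable(torch.from_numpy(X[s:s+batch_size]).cuda(0))\n",
"        preds.append(model(inputs).data.cpu().numpy())\n",
"    return np.vstack(preds)"
]
},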
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generate random data\n",
"\n",
"##### Data format:\n",
"Datatype - float32 (both X and Y) <br>\n",
"X.shape - (#samples, 1, #timepoints, #channels) <br>\n",
"Y.shape - (#samples)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X_train = np.random.rand(100, 1, 120, 64).astype('float32')  # np.random.rand generates values in [0, 1)\n",
"y_train = np.round(np.random.rand(100).astype('float32'))  # binary labels, so we round to 0 or 1\n",
"\n",
"X_val = np.random.rand(100, 1, 120, 64).astype('float32')\n",
"y_val = np.round(np.random.rand(100).astype('float32'))\n",
"\n",
"X_test = np.random.rand(100, 1, 120, 64).astype('float32')\n",
"y_test = np.round(np.random.rand(100).astype('float32'))"
]
},
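{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use real recordings instead of the random arrays above, reshape them into the documented format. A minimal sketch, assuming the EEG is stored as a (#trials, #channels, #timepoints) array; the `raw` array below is a stand-in, not real data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Sketch (assumption): go from (trials, channels, timepoints) to the\n",
"# (samples, 1, timepoints, channels) float32 layout expected by EEGNet.\n",
"raw = np.random.rand(100, 64, 120)  # stand-in for a real recording\n",
"X_real = np.transpose(raw, (0, 2, 1))[:, None, :, :].astype('float32')\n",
"y_real = np.zeros(100).astype('float32')  # one binary label per trial\n",
"print X_real.shape, y_real.shape  # (100, 1, 120, 64) (100,)"
]
},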
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Run"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 0\n",
"['acc', 'auc', 'fmeasure']\n",
"Training Loss 1.54113572836\n",
"Train - [0.54000000000000004, 0.59178743961352653, 0.70129870129870131]\n",
"Validation - [0.51000000000000001, 0.48539415766306526, 0.67549668874172186]\n",
"Test - [0.5, 0.50319999999999998, 0.66666666666666663]\n",
"\n",
"Epoch 1\n",
"['acc', 'auc', 'fmeasure']\n",
"Training Loss 1.42391115427\n",
"Train - [0.54000000000000004, 0.63888888888888895, 0.70129870129870131]\n",
"Validation - [0.51000000000000001, 0.47458983593437376, 0.67549668874172186]\n",
"Test - [0.5, 0.50439999999999996, 0.66666666666666663]\n",
"\n",
"Epoch 2\n",
"['acc', 'auc', 'fmeasure']\n",
"Training Loss 1.3422973156\n",
"Train - [0.55000000000000004, 0.67995169082125606, 0.70198675496688734]\n",
"Validation - [0.53000000000000003, 0.46898759503801518, 0.68456375838926176]\n",
"Test - [0.51000000000000001, 0.50800000000000001, 0.67114093959731547]\n",
"\n",
"Epoch 3\n",
"['acc', 'auc', 'fmeasure']\n",
"Training Loss 1.28801095486\n",
"Train - [0.63, 0.71054750402576483, 0.73758865248226957]\n",
"Validation - [0.48999999999999999, 0.4601840736294518, 0.63309352517985618]\n",
"Test - [0.52000000000000002, 0.51000000000000001, 0.66666666666666663]\n",
"\n",
"Epoch 4\n",
"['acc', 'auc', 'fmeasure']\n",
"Training Loss 1.25420039892\n",
"Train - [0.68999999999999995, 0.74476650563607083, 0.75590551181102361]\n",
"Validation - [0.42999999999999999, 0.44457783113245297, 0.53658536585365846]\n",
"Test - [0.51000000000000001, 0.51000000000000001, 0.6080000000000001]\n",
"\n",
"Epoch 5\n",
"['acc', 'auc', 'fmeasure']\n",
"Training Loss 1.22989922762\n",
"Train - [0.75, 0.77375201288244766, 0.77876106194690276]\n",
"Validation - [0.46000000000000002, 0.43937575030012005, 0.49056603773584906]\n",
"Test - [0.51000000000000001, 0.50440000000000007, 0.55855855855855863]\n",
"\n",
"Epoch 6\n",
"['acc', 'auc', 'fmeasure']\n",
"Training Loss 1.20727479458\n",
"Train - [0.76000000000000001, 0.79227053140096615, 0.7735849056603773]\n",
"Validation - [0.40999999999999998, 0.43897559023609439, 0.40404040404040403]\n",
"Test - [0.47999999999999998, 0.496, 0.5]\n",
"\n",
"Epoch 7\n",
"['acc', 'auc', 'fmeasure']\n",
"Training Loss 1.18265104294\n",
"Train - [0.76000000000000001, 0.81119162640901765, 0.7735849056603773]\n",
"Validation - [0.40999999999999998, 0.43897559023609445, 0.40404040404040403]\n",
"Test - [0.44, 0.48999999999999999, 0.44]\n",
"\n",
"Epoch 8\n",
"['acc', 'auc', 'fmeasure']\n",
"Training Loss 1.15454357862\n",
"Train - [0.80000000000000004, 0.8248792270531401, 0.81132075471698106]\n",
"Validation - [0.40000000000000002, 0.43497398959583838, 0.40000000000000002]\n",
"Test - [0.48999999999999999, 0.48560000000000003, 0.51428571428571423]\n",
"\n",
"Epoch 9\n",
"['acc', 'auc', 'fmeasure']\n",
"Training Loss 1.12422537804\n",
"Train - [0.81000000000000005, 0.83816425120772953, 0.82568807339449546]\n",
"Validation - [0.40999999999999998, 0.43177270908363347, 0.41584158415841577]\n",
"Test - [0.47999999999999998, 0.4768, 0.52727272727272734]\n"
]
}
],
"source": [
"batch_size = 32\n",
"\n",
"for epoch in range(10):  # loop over the dataset multiple times\n",
"    print \"\\nEpoch \", epoch\n",
"\n",
"    running_loss = 0.0\n",
"    for i in range(len(X_train)/batch_size-1):\n",
"        s = i*batch_size\n",
"        e = i*batch_size+batch_size\n",
"\n",
"        inputs = torch.from_numpy(X_train[s:e])\n",
"        labels = torch.FloatTensor(np.array([y_train[s:e]]).T*1.0)\n",
"\n",
"        # wrap them in Variable\n",
"        inputs, labels = Variable(inputs.cuda(0)), Variable(labels.cuda(0))\n",
"\n",
"        # zero the parameter gradients\n",
"        optimizer.zero_grad()\n",
"\n",
"        # forward + backward + optimize\n",
"        outputs = net(inputs)\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward()\n",
"\n",
"        optimizer.step()\n",
"\n",
"        running_loss += loss.data[0]\n",
"\n",
"    # Validation accuracy\n",
"    params = [\"acc\", \"auc\", \"fmeasure\"]\n",
"    print params\n",
"    print \"Training Loss \", running_loss\n",
"    print \"Train - \", evaluate(net, X_train, y_train, params)\n",
"    print \"Validation - \", evaluate(net, X_val, y_val, params)\n",
"    print \"Test - \", evaluate(net, X_test, y_test, params)"
]
},
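{
"cell_type": "markdown",
"metadata": {},
"source": [
"After training, you can get probabilities for new samples the same way the evaluate function does; anything above 0.5 rounds to class 1. A minimal sketch (the `X_new` array is a stand-in, not real data):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch (assumption): predicted probabilities for a few new samples, using the\n",
"# same Variable / .cuda(0) pattern as above. X_new is a stand-in for real data.\n",
"X_new = np.random.rand(5, 1, 120, 64).astype('float32')\n",
"probs = net(Variable(torch.from_numpy(X_new).cuda(0))).data.cpu().numpy()\n",
"print probs.ravel()            # sigmoid outputs in [0, 1]\n",
"print np.round(probs.ravel())  # hard 0/1 predictions"
]
},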
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}