--- a +++ b/EEGNet-PyTorch.ipynb @@ -0,0 +1,370 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "\"\"\"\n", + "Written by, \n", + "Sriram Ravindran, sriram@ucsd.edu\n", + "\n", + "Original paper - https://arxiv.org/abs/1611.08024\n", + "\n", + "Please reach out to me if you spot an error.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from sklearn.metrics import roc_auc_score, precision_score, recall_score, accuracy_score\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.autograd import Variable\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<p>Here's the description from the paper</p>\n", + "<img src=\"EEGNet.png\" style=\"width: 700px; float:left;\">" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Variable containing:\n", + " 0.7338\n", + "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n", + "\n" + ] + } + ], + "source": [ + "class EEGNet(nn.Module):\n", + " def __init__(self):\n", + " super(EEGNet, self).__init__()\n", + " self.T = 120\n", + " \n", + " # Layer 1\n", + " self.conv1 = nn.Conv2d(1, 16, (1, 64), padding = 0)\n", + " self.batchnorm1 = nn.BatchNorm2d(16, False)\n", + " \n", + " # Layer 2\n", + " self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))\n", + " self.conv2 = nn.Conv2d(1, 4, (2, 32))\n", + " self.batchnorm2 = nn.BatchNorm2d(4, False)\n", + " self.pooling2 = nn.MaxPool2d(2, 4)\n", + " \n", + " # Layer 3\n", + " self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))\n", + " self.conv3 = nn.Conv2d(4, 4, (8, 4))\n", + " self.batchnorm3 = nn.BatchNorm2d(4, False)\n", + " self.pooling3 = nn.MaxPool2d((2, 4))\n", + " \n", + " # FC Layer\n", + " # NOTE: This dimension will depend on the number of timestamps per sample in your data.\n", + " # I have 120 timepoints. \n", + " self.fc1 = nn.Linear(4*2*7, 1)\n", + " \n", + "\n", + " def forward(self, x):\n", + " # Layer 1\n", + " x = F.elu(self.conv1(x))\n", + " x = self.batchnorm1(x)\n", + " x = F.dropout(x, 0.25)\n", + " x = x.permute(0, 3, 1, 2)\n", + " \n", + " # Layer 2\n", + " x = self.padding1(x)\n", + " x = F.elu(self.conv2(x))\n", + " x = self.batchnorm2(x)\n", + " x = F.dropout(x, 0.25)\n", + " x = self.pooling2(x)\n", + " \n", + " # Layer 3\n", + " x = self.padding2(x)\n", + " x = F.elu(self.conv3(x))\n", + " x = self.batchnorm3(x)\n", + " x = F.dropout(x, 0.25)\n", + " x = self.pooling3(x)\n", + " \n", + " # FC Layer\n", + " x = x.view(-1, 4*2*7)\n", + " x = F.sigmoid(self.fc1(x))\n", + " return x\n", + "\n", + "\n", + "net = EEGNet().cuda(0)\n", + "print net.forward(Variable(torch.Tensor(np.random.rand(1, 1, 120, 64)).cuda(0)))\n", + "criterion = nn.BCELoss()\n", + "optimizer = optim.Adam(net.parameters())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Evaluate function returns values of different criteria like accuracy, precision etc. \n", + "In case you face memory overflow issues, use batch size to control how many samples get evaluated at one time. Use a batch_size that is a factor of length of samples. This ensures that you won't miss any samples." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def evaluate(model, X, Y, params = [\"acc\"]):\n", + " results = []\n", + " batch_size = 100\n", + " \n", + " predicted = []\n", + " \n", + " for i in range(len(X)/batch_size):\n", + " s = i*batch_size\n", + " e = i*batch_size+batch_size\n", + " \n", + " inputs = Variable(torch.from_numpy(X[s:e]).cuda(0))\n", + " pred = model(inputs)\n", + " \n", + " predicted.append(pred.data.cpu().numpy())\n", + " \n", + " \n", + " inputs = Variable(torch.from_numpy(X).cuda(0))\n", + " predicted = model(inputs)\n", + " \n", + " predicted = predicted.data.cpu().numpy()\n", + " \n", + " for param in params:\n", + " if param == 'acc':\n", + " results.append(accuracy_score(Y, np.round(predicted)))\n", + " if param == \"auc\":\n", + " results.append(roc_auc_score(Y, predicted))\n", + " if param == \"recall\":\n", + " results.append(recall_score(Y, np.round(predicted)))\n", + " if param == \"precision\":\n", + " results.append(precision_score(Y, np.round(predicted)))\n", + " if param == \"fmeasure\":\n", + " precision = precision_score(Y, np.round(predicted))\n", + " recall = recall_score(Y, np.round(predicted))\n", + " results.append(2*precision*recall/ (precision+recall))\n", + " return results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generate random data\n", + "\n", + "##### Data format:\n", + "Datatype - float32 (both X and Y) <br>\n", + "X.shape - (#samples, 1, #timepoints, #channels) <br>\n", + "Y.shape - (#samples)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "X_train = np.random.rand(100, 1, 120, 64).astype('float32') # np.random.rand generates between [0, 1)\n", + "y_train = np.round(np.random.rand(100).astype('float32')) # binary data, so we round it to 0 or 1.\n", + "\n", + "X_val = np.random.rand(100, 1, 120, 64).astype('float32')\n", + "y_val = np.round(np.random.rand(100).astype('float32'))\n", + "\n", + "X_test = np.random.rand(100, 1, 120, 64).astype('float32')\n", + "y_test = np.round(np.random.rand(100).astype('float32'))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Run" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch 0\n", + "['acc', 'auc', 'fmeasure']\n", + "Training Loss 1.54113572836\n", + "Train - [0.54000000000000004, 0.59178743961352653, 0.70129870129870131]\n", + "Validation - [0.51000000000000001, 0.48539415766306526, 0.67549668874172186]\n", + "Test - [0.5, 0.50319999999999998, 0.66666666666666663]\n", + "\n", + "Epoch 1\n", + "['acc', 'auc', 'fmeasure']\n", + "Training Loss 1.42391115427\n", + "Train - [0.54000000000000004, 0.63888888888888895, 0.70129870129870131]\n", + "Validation - [0.51000000000000001, 0.47458983593437376, 0.67549668874172186]\n", + "Test - [0.5, 0.50439999999999996, 0.66666666666666663]\n", + "\n", + "Epoch 2\n", + "['acc', 'auc', 'fmeasure']\n", + "Training Loss 1.3422973156\n", + "Train - [0.55000000000000004, 0.67995169082125606, 0.70198675496688734]\n", + "Validation - [0.53000000000000003, 0.46898759503801518, 0.68456375838926176]\n", + "Test - [0.51000000000000001, 0.50800000000000001, 0.67114093959731547]\n", + "\n", + "Epoch 3\n", + "['acc', 'auc', 'fmeasure']\n", + "Training Loss 1.28801095486\n", + "Train - [0.63, 0.71054750402576483, 0.73758865248226957]\n", + "Validation - [0.48999999999999999, 0.4601840736294518, 0.63309352517985618]\n", + "Test - [0.52000000000000002, 0.51000000000000001, 0.66666666666666663]\n", + "\n", + "Epoch 4\n", + "['acc', 'auc', 'fmeasure']\n", + "Training Loss 1.25420039892\n", + "Train - [0.68999999999999995, 0.74476650563607083, 0.75590551181102361]\n", + "Validation - [0.42999999999999999, 0.44457783113245297, 0.53658536585365846]\n", + "Test - [0.51000000000000001, 0.51000000000000001, 0.6080000000000001]\n", + "\n", + "Epoch 5\n", + "['acc', 'auc', 'fmeasure']\n", + "Training Loss 1.22989922762\n", + "Train - [0.75, 0.77375201288244766, 0.77876106194690276]\n", + "Validation - [0.46000000000000002, 0.43937575030012005, 0.49056603773584906]\n", + "Test - [0.51000000000000001, 0.50440000000000007, 0.55855855855855863]\n", + "\n", + "Epoch 6\n", + "['acc', 'auc', 'fmeasure']\n", + "Training Loss 1.20727479458\n", + "Train - [0.76000000000000001, 0.79227053140096615, 0.7735849056603773]\n", + "Validation - [0.40999999999999998, 0.43897559023609439, 0.40404040404040403]\n", + "Test - [0.47999999999999998, 0.496, 0.5]\n", + "\n", + "Epoch 7\n", + "['acc', 'auc', 'fmeasure']\n", + "Training Loss 1.18265104294\n", + "Train - [0.76000000000000001, 0.81119162640901765, 0.7735849056603773]\n", + "Validation - [0.40999999999999998, 0.43897559023609445, 0.40404040404040403]\n", + "Test - [0.44, 0.48999999999999999, 0.44]\n", + "\n", + "Epoch 8\n", + "['acc', 'auc', 'fmeasure']\n", + "Training Loss 1.15454357862\n", + "Train - [0.80000000000000004, 0.8248792270531401, 0.81132075471698106]\n", + "Validation - [0.40000000000000002, 0.43497398959583838, 0.40000000000000002]\n", + "Test - [0.48999999999999999, 0.48560000000000003, 0.51428571428571423]\n", + "\n", + "Epoch 9\n", + "['acc', 'auc', 'fmeasure']\n", + "Training Loss 1.12422537804\n", + "Train - [0.81000000000000005, 0.83816425120772953, 0.82568807339449546]\n", + "Validation - [0.40999999999999998, 0.43177270908363347, 0.41584158415841577]\n", + "Test - [0.47999999999999998, 0.4768, 0.52727272727272734]\n" + ] + } + ], + "source": [ + "batch_size = 32\n", + "\n", + "for epoch in range(10): # loop over the dataset multiple times\n", + " print \"\\nEpoch \", epoch\n", + " \n", + " running_loss = 0.0\n", + " for i in range(len(X_train)/batch_size-1):\n", + " s = i*batch_size\n", + " e = i*batch_size+batch_size\n", + " \n", + " inputs = torch.from_numpy(X_train[s:e])\n", + " labels = torch.FloatTensor(np.array([y_train[s:e]]).T*1.0)\n", + " \n", + " # wrap them in Variable\n", + " inputs, labels = Variable(inputs.cuda(0)), Variable(labels.cuda(0))\n", + "\n", + " # zero the parameter gradients\n", + " optimizer.zero_grad()\n", + "\n", + " # forward + backward + optimize\n", + " outputs = net(inputs)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " \n", + " \n", + " optimizer.step()\n", + " \n", + " running_loss += loss.data[0]\n", + " \n", + " # Validation accuracy\n", + " params = [\"acc\", \"auc\", \"fmeasure\"]\n", + " print params\n", + " print \"Training Loss \", running_loss\n", + " print \"Train - \", evaluate(net, X_train, y_train, params)\n", + " print \"Validation - \", evaluate(net, X_val, y_val, params)\n", + " print \"Test - \", evaluate(net, X_test, y_test, params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}