[5889be]: / EEGNet-PyTorch.ipynb

Download this file

371 lines (370 with data), 12.4 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Written by, \n",
    "Sriram Ravindran, sriram@ucsd.edu\n",
    "\n",
    "Original paper - https://arxiv.org/abs/1611.08024\n",
    "\n",
    "Please reach out to me if you spot an error.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import roc_auc_score, precision_score, recall_score, accuracy_score\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p>Here's the description from the paper</p>\n",
    "<img src=\"EEGNet.png\" style=\"width: 700px; float:left;\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variable containing:\n",
      " 0.7338\n",
      "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "class EEGNet(nn.Module):\n",
    "    \"\"\"EEGNet-style convolutional classifier with a single sigmoid output.\n",
    "\n",
    "    forward() is written for inputs shaped (batch, 1, 120, 64), the shape of the\n",
    "    arrays used throughout this notebook, and returns a (batch, 1) tensor of\n",
    "    values in [0, 1].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(EEGNet, self).__init__()\n",
    "        self.T = 120\n",
    "\n",
    "        # Layer 1\n",
    "        self.conv1 = nn.Conv2d(1, 16, (1, 64), padding = 0)\n",
    "        # The second positional argument of BatchNorm2d is `eps`. Passing False\n",
    "        # there set eps to 0 and removed the guard against division by zero,\n",
    "        # so the library default is used for all three batch-norm layers.\n",
    "        self.batchnorm1 = nn.BatchNorm2d(16)\n",
    "\n",
    "        # Layer 2\n",
    "        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))\n",
    "        self.conv2 = nn.Conv2d(1, 4, (2, 32))\n",
    "        self.batchnorm2 = nn.BatchNorm2d(4)\n",
    "        self.pooling2 = nn.MaxPool2d(2, 4)\n",
    "\n",
    "        # Layer 3\n",
    "        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))\n",
    "        self.conv3 = nn.Conv2d(4, 4, (8, 4))\n",
    "        self.batchnorm3 = nn.BatchNorm2d(4)\n",
    "        self.pooling3 = nn.MaxPool2d((2, 4))\n",
    "\n",
    "        # FC Layer\n",
    "        # NOTE: This dimension will depend on the number of timestamps per sample in your data.\n",
    "        # With 120 timepoints the feature map entering the FC layer has 4*2*7 values per sample.\n",
    "        self.fc1 = nn.Linear(4*2*7, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the three conv blocks and the FC layer; return sigmoid outputs.\"\"\"\n",
    "        # Dropout is tied to self.training so that net.train() / net.eval()\n",
    "        # switch it on and off; the default of F.dropout's `training` flag\n",
    "        # must not be relied on.\n",
    "\n",
    "        # Layer 1\n",
    "        x = F.elu(self.conv1(x))\n",
    "        x = self.batchnorm1(x)\n",
    "        x = F.dropout(x, 0.25, training=self.training)\n",
    "        # Move the 16 conv1 feature maps onto a spatial axis for the next 2-D convolution.\n",
    "        x = x.permute(0, 3, 1, 2)\n",
    "\n",
    "        # Layer 2\n",
    "        x = self.padding1(x)\n",
    "        x = F.elu(self.conv2(x))\n",
    "        x = self.batchnorm2(x)\n",
    "        x = F.dropout(x, 0.25, training=self.training)\n",
    "        x = self.pooling2(x)\n",
    "\n",
    "        # Layer 3\n",
    "        x = self.padding2(x)\n",
    "        x = F.elu(self.conv3(x))\n",
    "        x = self.batchnorm3(x)\n",
    "        x = F.dropout(x, 0.25, training=self.training)\n",
    "        x = self.pooling3(x)\n",
    "\n",
    "        # FC Layer\n",
    "        x = x.view(-1, 4*2*7)\n",
    "        x = torch.sigmoid(self.fc1(x))\n",
    "        return x\n",
    "\n",
    "\n",
    "net = EEGNet().cuda(0)\n",
    "# Smoke test: push one random (1, 1, 120, 64) sample through the untrained network.\n",
    "# Calling the module instead of .forward goes through nn.Module.__call__, and the\n",
    "# parenthesised single-argument print behaves the same on Python 2 and Python 3.\n",
    "print(net(Variable(torch.Tensor(np.random.rand(1, 1, 120, 64)).cuda(0))))\n",
    "# BCELoss expects probabilities, which the sigmoid at the end of EEGNet.forward provides.\n",
    "criterion = nn.BCELoss()\n",
    "optimizer = optim.Adam(net.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The evaluate function returns the requested metrics (accuracy, AUC, recall, precision, F-measure). \n",
    "If you run into memory issues, use the batch size to control how many samples are evaluated at a time. Choose a batch_size that evenly divides the number of samples so that no samples are skipped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def evaluate(model, X, Y, params = [\"acc\"], batch_size = 100):\n",
    "    \"\"\"Score `model` on (X, Y) and return the requested metrics as a list.\n",
    "\n",
    "    model      -- network called as model(inputs); inputs are moved to GPU 0.\n",
    "    X          -- numpy array of samples, sliced along its first axis.\n",
    "    Y          -- true binary labels, one per sample.\n",
    "    params     -- metric names, any of 'acc', 'auc', 'recall', 'precision',\n",
    "                  'fmeasure'; unrecognised names are skipped.\n",
    "    batch_size -- number of samples sent through the model per forward pass.\n",
    "\n",
    "    Returns one value per recognised entry of `params`, in the same order.\n",
    "    \"\"\"\n",
    "    # Score in inference mode, then put the model back the way it was so a\n",
    "    # surrounding training loop is not affected.\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        batches = []\n",
    "        # Stepping by batch_size covers every sample, including a final short\n",
    "        # batch, and never pushes all of X through the model at once.\n",
    "        for start in range(0, len(X), batch_size):\n",
    "            inputs = Variable(torch.from_numpy(X[start:start + batch_size]).cuda(0))\n",
    "            batches.append(model(inputs).data.cpu().numpy())\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "\n",
    "    # Flatten the per-batch (n, 1) outputs into one score per sample.\n",
    "    predicted = np.concatenate(batches).ravel()\n",
    "    predicted_labels = np.round(predicted)\n",
    "\n",
    "    results = []\n",
    "    for param in params:\n",
    "        if param == 'acc':\n",
    "            results.append(accuracy_score(Y, predicted_labels))\n",
    "        if param == \"auc\":\n",
    "            results.append(roc_auc_score(Y, predicted))\n",
    "        if param == \"recall\":\n",
    "            results.append(recall_score(Y, predicted_labels))\n",
    "        if param == \"precision\":\n",
    "            results.append(precision_score(Y, predicted_labels))\n",
    "        if param == \"fmeasure\":\n",
    "            precision = precision_score(Y, predicted_labels)\n",
    "            recall = recall_score(Y, predicted_labels)\n",
    "            total = precision + recall\n",
    "            # Both scores are 0 when there are no true positives; report 0 instead of dividing by zero.\n",
    "            results.append(2*precision*recall/total if total > 0 else 0.0)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Generate random data\n",
    "\n",
    "##### Data format:\n",
    "Datatype - float32 (both X and Y) <br>\n",
    "X.shape - (#samples, 1, #timepoints,  #channels) <br>\n",
    "Y.shape - (#samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def make_random_split(rng, n_samples = 100, n_timepoints = 120, n_channels = 64):\n",
    "    \"\"\"Draw one random (X, y) pair from `rng`, a numpy RandomState.\n",
    "\n",
    "    X is float32 with shape (n_samples, 1, n_timepoints, n_channels) and values in [0, 1).\n",
    "    y is float32 with shape (n_samples,), rounded to 0 or 1 to act as binary labels.\n",
    "    \"\"\"\n",
    "    X = rng.rand(n_samples, 1, n_timepoints, n_channels).astype('float32')\n",
    "    y = np.round(rng.rand(n_samples).astype('float32'))\n",
    "    return X, y\n",
    "\n",
    "\n",
    "# A private, fixed-seed generator makes this cell produce the same arrays on every run\n",
    "# without touching numpy's global random state.\n",
    "data_rng = np.random.RandomState(0)\n",
    "\n",
    "X_train, y_train = make_random_split(data_rng)\n",
    "X_val, y_val = make_random_split(data_rng)\n",
    "X_test, y_test = make_random_split(data_rng)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch  0\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.54113572836\n",
      "Train -  [0.54000000000000004, 0.59178743961352653, 0.70129870129870131]\n",
      "Validation -  [0.51000000000000001, 0.48539415766306526, 0.67549668874172186]\n",
      "Test -  [0.5, 0.50319999999999998, 0.66666666666666663]\n",
      "\n",
      "Epoch  1\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.42391115427\n",
      "Train -  [0.54000000000000004, 0.63888888888888895, 0.70129870129870131]\n",
      "Validation -  [0.51000000000000001, 0.47458983593437376, 0.67549668874172186]\n",
      "Test -  [0.5, 0.50439999999999996, 0.66666666666666663]\n",
      "\n",
      "Epoch  2\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.3422973156\n",
      "Train -  [0.55000000000000004, 0.67995169082125606, 0.70198675496688734]\n",
      "Validation -  [0.53000000000000003, 0.46898759503801518, 0.68456375838926176]\n",
      "Test -  [0.51000000000000001, 0.50800000000000001, 0.67114093959731547]\n",
      "\n",
      "Epoch  3\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.28801095486\n",
      "Train -  [0.63, 0.71054750402576483, 0.73758865248226957]\n",
      "Validation -  [0.48999999999999999, 0.4601840736294518, 0.63309352517985618]\n",
      "Test -  [0.52000000000000002, 0.51000000000000001, 0.66666666666666663]\n",
      "\n",
      "Epoch  4\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.25420039892\n",
      "Train -  [0.68999999999999995, 0.74476650563607083, 0.75590551181102361]\n",
      "Validation -  [0.42999999999999999, 0.44457783113245297, 0.53658536585365846]\n",
      "Test -  [0.51000000000000001, 0.51000000000000001, 0.6080000000000001]\n",
      "\n",
      "Epoch  5\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.22989922762\n",
      "Train -  [0.75, 0.77375201288244766, 0.77876106194690276]\n",
      "Validation -  [0.46000000000000002, 0.43937575030012005, 0.49056603773584906]\n",
      "Test -  [0.51000000000000001, 0.50440000000000007, 0.55855855855855863]\n",
      "\n",
      "Epoch  6\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.20727479458\n",
      "Train -  [0.76000000000000001, 0.79227053140096615, 0.7735849056603773]\n",
      "Validation -  [0.40999999999999998, 0.43897559023609439, 0.40404040404040403]\n",
      "Test -  [0.47999999999999998, 0.496, 0.5]\n",
      "\n",
      "Epoch  7\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.18265104294\n",
      "Train -  [0.76000000000000001, 0.81119162640901765, 0.7735849056603773]\n",
      "Validation -  [0.40999999999999998, 0.43897559023609445, 0.40404040404040403]\n",
      "Test -  [0.44, 0.48999999999999999, 0.44]\n",
      "\n",
      "Epoch  8\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.15454357862\n",
      "Train -  [0.80000000000000004, 0.8248792270531401, 0.81132075471698106]\n",
      "Validation -  [0.40000000000000002, 0.43497398959583838, 0.40000000000000002]\n",
      "Test -  [0.48999999999999999, 0.48560000000000003, 0.51428571428571423]\n",
      "\n",
      "Epoch  9\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.12422537804\n",
      "Train -  [0.81000000000000005, 0.83816425120772953, 0.82568807339449546]\n",
      "Validation -  [0.40999999999999998, 0.43177270908363347, 0.41584158415841577]\n",
      "Test -  [0.47999999999999998, 0.4768, 0.52727272727272734]\n"
     ]
    }
   ],
   "source": [
    "batch_size = 32\n",
    "params = [\"acc\", \"auc\", \"fmeasure\"]\n",
    "\n",
    "for epoch in range(10):  # loop over the dataset multiple times\n",
    "    # Single-argument print(...) calls behave the same on Python 2 and Python 3.\n",
    "    print(\"\\nEpoch  {}\".format(epoch))\n",
    "\n",
    "    # Make sure the network is in training mode before updating its weights.\n",
    "    net.train()\n",
    "\n",
    "    running_loss = 0.0\n",
    "    # Step through every sample, including the final (possibly shorter) batch.\n",
    "    for s in range(0, len(X_train), batch_size):\n",
    "        e = s + batch_size\n",
    "\n",
    "        inputs = torch.from_numpy(X_train[s:e])\n",
    "        # Labels become a float32 column vector so they line up with the (batch, 1) network output.\n",
    "        labels = torch.from_numpy(np.asarray(y_train[s:e], dtype=np.float32)).view(-1, 1)\n",
    "\n",
    "        # wrap them in Variable\n",
    "        inputs, labels = Variable(inputs.cuda(0)), Variable(labels.cuda(0))\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # loss.item() exists on newer PyTorch; older releases expose the value as loss.data[0].\n",
    "        running_loss += loss.item() if hasattr(loss, 'item') else loss.data[0]\n",
    "\n",
    "    # running_loss is the sum of the per-batch losses for this epoch.\n",
    "    print(params)\n",
    "    print(\"Training Loss  {}\".format(running_loss))\n",
    "    print(\"Train -  {}\".format(evaluate(net, X_train, y_train, params)))\n",
    "    print(\"Validation -  {}\".format(evaluate(net, X_val, y_val, params)))\n",
    "    print(\"Test -  {}\".format(evaluate(net, X_test, y_test, params)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}