--- a +++ b/brainDecode/deprecated/1 - Two-Classes Classification (BNCI).ipynb @@ -0,0 +1,603 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(160, 15, 2560)\n", + "(160,)\n" + ] + } + ], + "source": [ + "\"\"\"\n", + "Format dataset, we read the file for the desired subject, and parse the data to extract:\n", + "- samplingRate\n", + "- trialLength\n", + "- X, a M x N x K matrix, which stands for trial x chan x samples\n", + " the actual values are 160 x 15 x 2560\n", + "- y, a M vector containing the labels {0,1}\n", + "\n", + "ref:\n", + "Dataset description: https://lampx.tugraz.at/~bci/database/002-2014/description.pdf\n", + "\"\"\"\n", + "\n", + "import scipy.io as sio\n", + "import numpy as np\n", + "\n", + "\n", + "# prepare data containers\n", + "y = []\n", + "X = []\n", + "\n", + "\"\"\"\n", + "trainingFileList = [#'BBCIData/S14T.mat', \n", + " #'BBCIData/S13T.mat', \n", + " #'BBCIData/S12T.mat', \n", + " #'BBCIData/S11T.mat', \n", + " #'BBCIData/S10T.mat', \n", + " #'BBCIData/S09T.mat', \n", + " #'BBCIData/S08T.mat', \n", + " #'BBCIData/S07T.mat', \n", + " #'BBCIData/S06T.mat', \n", + " #'BBCIData/S05T.mat', \n", + " #'BBCIData/S04T.mat', \n", + " #'BBCIData/S03T.mat', \n", + " #'BBCIData/S02T.mat', \n", + " 'BBCIData/S01T.mat']\n", + "\n", + "validationFileList = [#'BBCIData/S14E.mat', \n", + " #'BBCIData/S13E.mat', \n", + " #'BBCIData/S12E.mat', \n", + " #'BBCIData/S11E.mat', \n", + " #'BBCIData/S10E.mat', \n", + " #'BBCIData/S09E.mat', \n", + " #'BBCIData/S08E.mat', \n", + " #'BBCIData/S07E.mat', \n", + " #'BBCIData/S06E.mat', \n", + " #'BBCIData/S05E.mat', \n", + " #'BBCIData/S04E.mat', \n", + " #'BBCIData/S03E.mat', \n", + " #'BBCIData/S02E.mat', \n", + " 'BBCIData/S01E.mat']\n", + "\"\"\"\n", + "\n", + "trainingFileList = ['BBCIData/S08T.mat']\n", + "\n", + "validationFileList = ['BBCIData/S08E.mat']\n", + "\n", + "for i in range(len(trainingFileList)):\n", + " # read file\n", + " d1T = sio.loadmat(trainingFileList[i])\n", + " d1E = sio.loadmat(validationFileList[i])\n", + " \n", + " samplingRate = d1T['data'][0][0][0][0][3][0][0]\n", + " trialLength = 5*samplingRate\n", + "\n", + "\n", + " # run through all training runs\n", + " for run in range(5):\n", + " y.append(d1T['data'][0][run][0][0][2][0]) # labels\n", + " timestamps = d1T['data'][0][run][0][0][1][0] # timestamps\n", + " rawData = d1T['data'][0][run][0][0][0].transpose() # chan x data\n", + "\n", + " # parse out data based on timestamps\n", + " for start in timestamps:\n", + " end = start + trialLength\n", + " X.append(rawData[:,start:end]) #15 x 2560\n", + "\n", + "\n", + " # run through all validation runs (we do not discriminate at this point)\n", + " for run in range(3):\n", + " y.append(d1E['data'][0][run][0][0][2][0]) # labels\n", + " timestamps = d1E['data'][0][run][0][0][1][0] # timestamps\n", + " rawData = d1E['data'][0][run][0][0][0].transpose() # chan x data\n", + "\n", + " # parse out data based on timestamps\n", + " for start in timestamps:\n", + " end = start + trialLength\n", + " X.append(rawData[:,start:end]) #15 x 2560\n", + "\n", + " del rawData\n", + " del d1T\n", + " del d1E\n", + "\n", + "# arrange data into numpy arrays\n", + "# also torch expect float32 for samples\n", + "# and int64 for labels {0,1}\n", + "X = np.array(X).astype(np.float32)\n", + "y = (np.array(y).flatten()-1).astype(np.int64)\n", + "print(X.shape)\n", + "print(y.shape)\n", + "\n", + "# erase unused references\n", + "d1T = []\n", + "d1E = []\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "from braindecode.datautil.signal_target import SignalAndTarget\n", + "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n", + "from torch import nn\n", + "from braindecode.torch_ext.util import set_random_seeds \n", + "from torch import optim\n", + "import torch\n", + "\n", + "idx = np.random.permutation(X.shape[0])\n", + "\n", + "X = X[idx,:,:]\n", + "y = y[idx]\n", + "\n", + "#print(X.shape)\n", + "#print(y.shape)\n", + "\n", + "nb_train_trials = int(np.floor(5/8*X.shape[0]))\n", + "\n", + "\n", + "train_set = SignalAndTarget(X[:nb_train_trials], y=y[:nb_train_trials])\n", + "test_set = SignalAndTarget(X[nb_train_trials:], y=y[nb_train_trials:])\n", + "\n", + "#train_set = SignalAndTarget(X[:nb_train_trials], y=y[:nb_train_trials])\n", + "#test_set = SignalAndTarget(X[nb_train_trials:nb_test_trials], y=y[nb_train_trials:nb_test_trials])\n", + "\n", + "# Set if you want to use GPU\n", + "# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.\n", + "cuda = torch.cuda.is_available()\n", + "set_random_seeds(seed=20170629, cuda=cuda)\n", + "n_classes = 2\n", + "in_chans = train_set.X.shape[1]\n", + "# final_conv_length = auto ensures we only get a single output in the time dimension\n", + "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,\n", + " input_time_length=train_set.X.shape[2],\n", + " final_conv_length='auto').create_network()\n", + "if cuda:\n", + " model.cuda()\n", + "\n", + "optimizer = optim.Adam(model.parameters())\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0\n", + "Train Loss: 23.12267\n", + "Train Accuracy: 51.0%\n", + "Test Loss: 23.45116\n", + "Test Accuracy: 46.7%\n", + "Epoch 1\n", + "Train Loss: 11.42247\n", + "Train Accuracy: 50.0%\n", + "Test Loss: 11.18545\n", + "Test Accuracy: 51.7%\n", + "Epoch 2\n", + "Train Loss: 6.14801\n", + "Train Accuracy: 49.0%\n", + "Test Loss: 5.97394\n", + "Test Accuracy: 50.0%\n", + "Epoch 3\n", + "Train Loss: 7.08030\n", + "Train Accuracy: 51.0%\n", + "Test Loss: 7.17894\n", + "Test Accuracy: 46.7%\n", + "Epoch 4\n", + "Train Loss: 2.78045\n", + "Train Accuracy: 51.0%\n", + "Test Loss: 2.80417\n", + "Test Accuracy: 48.3%\n", + "Epoch 5\n", + "Train Loss: 3.10231\n", + "Train Accuracy: 49.0%\n", + "Test Loss: 3.01706\n", + "Test Accuracy: 51.7%\n", + "Epoch 6\n", + "Train Loss: 0.77492\n", + "Train Accuracy: 59.0%\n", + "Test Loss: 0.73021\n", + "Test Accuracy: 55.0%\n", + "Epoch 7\n", + "Train Loss: 2.08993\n", + "Train Accuracy: 53.0%\n", + "Test Loss: 2.06789\n", + "Test Accuracy: 48.3%\n", + "Epoch 8\n", + "Train Loss: 0.78730\n", + "Train Accuracy: 64.0%\n", + "Test Loss: 0.69451\n", + "Test Accuracy: 58.3%\n", + "Epoch 9\n", + "Train Loss: 0.72214\n", + "Train Accuracy: 65.0%\n", + "Test Loss: 0.62667\n", + "Test Accuracy: 58.3%\n", + "Epoch 10\n", + "Train Loss: 1.04905\n", + "Train Accuracy: 57.0%\n", + "Test Loss: 0.94240\n", + "Test Accuracy: 43.3%\n", + "Epoch 11\n", + "Train Loss: 1.06353\n", + "Train Accuracy: 62.0%\n", + "Test Loss: 0.92665\n", + "Test Accuracy: 63.3%\n", + "Epoch 12\n", + "Train Loss: 0.73815\n", + "Train Accuracy: 67.0%\n", + "Test Loss: 0.61124\n", + "Test Accuracy: 60.0%\n", + "Epoch 13\n", + "Train Loss: 1.21130\n", + "Train Accuracy: 56.0%\n", + "Test Loss: 1.06474\n", + "Test Accuracy: 46.7%\n", + "Epoch 14\n", + "Train Loss: 0.57069\n", + "Train Accuracy: 73.0%\n", + "Test Loss: 0.44313\n", + "Test Accuracy: 56.7%\n", + "Epoch 15\n", + "Train Loss: 0.95924\n", + "Train Accuracy: 62.0%\n", + "Test Loss: 0.83465\n", + "Test Accuracy: 58.3%\n", + "Epoch 16\n", + "Train Loss: 0.61197\n", + "Train Accuracy: 68.0%\n", + "Test Loss: 0.45661\n", + "Test Accuracy: 45.0%\n", + "Epoch 17\n", + "Train Loss: 0.68676\n", + "Train Accuracy: 67.0%\n", + "Test Loss: 0.52255\n", + "Test Accuracy: 51.7%\n", + "Epoch 18\n", + "Train Loss: 0.59275\n", + "Train Accuracy: 71.0%\n", + "Test Loss: 0.49090\n", + "Test Accuracy: 61.7%\n", + "Epoch 19\n", + "Train Loss: 0.46892\n", + "Train Accuracy: 78.0%\n", + "Test Loss: 0.37154\n", + "Test Accuracy: 51.7%\n", + "Epoch 20\n", + "Train Loss: 0.87564\n", + "Train Accuracy: 63.0%\n", + "Test Loss: 0.72043\n", + "Test Accuracy: 51.7%\n", + "Epoch 21\n", + "Train Loss: 0.54479\n", + "Train Accuracy: 73.0%\n", + "Test Loss: 0.42003\n", + "Test Accuracy: 51.7%\n", + "Epoch 22\n", + "Train Loss: 0.47825\n", + "Train Accuracy: 80.0%\n", + "Test Loss: 0.40232\n", + "Test Accuracy: 53.3%\n", + "Epoch 23\n", + "Train Loss: 0.47127\n", + "Train Accuracy: 80.0%\n", + "Test Loss: 0.39606\n", + "Test Accuracy: 55.0%\n", + "Epoch 24\n", + "Train Loss: 0.39154\n", + "Train Accuracy: 83.0%\n", + "Test Loss: 0.29514\n", + "Test Accuracy: 48.3%\n", + "Epoch 25\n", + "Train Loss: 0.42451\n", + "Train Accuracy: 80.0%\n", + "Test Loss: 0.29410\n", + "Test Accuracy: 53.3%\n", + "Epoch 26\n", + "Train Loss: 0.37418\n", + "Train Accuracy: 85.0%\n", + "Test Loss: 0.26566\n", + "Test Accuracy: 53.3%\n", + "Epoch 27\n", + "Train Loss: 0.35942\n", + "Train Accuracy: 84.0%\n", + "Test Loss: 0.27371\n", + "Test Accuracy: 53.3%\n", + "Epoch 28\n", + "Train Loss: 0.37682\n", + "Train Accuracy: 83.0%\n", + "Test Loss: 0.31713\n", + "Test Accuracy: 55.0%\n", + "Epoch 29\n", + "Train Loss: 0.34180\n", + "Train Accuracy: 83.0%\n", + "Test Loss: 0.25558\n", + "Test Accuracy: 55.0%\n", + "Epoch 30\n", + "Train Loss: 0.33804\n", + "Train Accuracy: 87.0%\n", + "Test Loss: 0.23607\n", + "Test Accuracy: 53.3%\n", + "Epoch 31\n", + "Train Loss: 0.31833\n", + "Train Accuracy: 86.0%\n", + "Test Loss: 0.22466\n", + "Test Accuracy: 51.7%\n", + "Epoch 32\n", + "Train Loss: 0.29819\n", + "Train Accuracy: 87.0%\n", + "Test Loss: 0.22489\n", + "Test Accuracy: 50.0%\n", + "Epoch 33\n", + "Train Loss: 0.31271\n", + "Train Accuracy: 89.0%\n", + "Test Loss: 0.25959\n", + "Test Accuracy: 55.0%\n", + "Epoch 34\n", + "Train Loss: 0.29346\n", + "Train Accuracy: 88.0%\n", + "Test Loss: 0.23652\n", + "Test Accuracy: 53.3%\n", + "Epoch 35\n", + "Train Loss: 0.28696\n", + "Train Accuracy: 86.0%\n", + "Test Loss: 0.21594\n", + "Test Accuracy: 51.7%\n", + "Epoch 36\n", + "Train Loss: 0.28489\n", + "Train Accuracy: 87.0%\n", + "Test Loss: 0.20726\n", + "Test Accuracy: 53.3%\n", + "Epoch 37\n", + "Train Loss: 0.25652\n", + "Train Accuracy: 90.0%\n", + "Test Loss: 0.18759\n", + "Test Accuracy: 48.3%\n", + "Epoch 38\n", + "Train Loss: 0.28203\n", + "Train Accuracy: 89.0%\n", + "Test Loss: 0.22545\n", + "Test Accuracy: 51.7%\n", + "Epoch 39\n", + "Train Loss: 0.24893\n", + "Train Accuracy: 93.0%\n", + "Test Loss: 0.18108\n", + "Test Accuracy: 50.0%\n", + "Epoch 40\n", + "Train Loss: 0.26061\n", + "Train Accuracy: 90.0%\n", + "Test Loss: 0.17375\n", + "Test Accuracy: 50.0%\n", + "Epoch 41\n", + "Train Loss: 0.24927\n", + "Train Accuracy: 92.0%\n", + "Test Loss: 0.16775\n", + "Test Accuracy: 51.7%\n", + "Epoch 42\n", + "Train Loss: 0.23456\n", + "Train Accuracy: 92.0%\n", + "Test Loss: 0.16739\n", + "Test Accuracy: 51.7%\n", + "Epoch 43\n", + "Train Loss: 0.23747\n", + "Train Accuracy: 92.0%\n", + "Test Loss: 0.18819\n", + "Test Accuracy: 51.7%\n", + "Epoch 44\n", + "Train Loss: 0.22980\n", + "Train Accuracy: 92.0%\n", + "Test Loss: 0.18161\n", + "Test Accuracy: 51.7%\n", + "Epoch 45\n", + "Train Loss: 0.22144\n", + "Train Accuracy: 94.0%\n", + "Test Loss: 0.16714\n", + "Test Accuracy: 50.0%\n", + "Epoch 46\n", + "Train Loss: 0.23376\n", + "Train Accuracy: 92.0%\n", + "Test Loss: 0.16477\n", + "Test Accuracy: 51.7%\n", + "Epoch 47\n", + "Train Loss: 0.20786\n", + "Train Accuracy: 95.0%\n", + "Test Loss: 0.15308\n", + "Test Accuracy: 50.0%\n", + "Epoch 48\n", + "Train Loss: 0.20483\n", + "Train Accuracy: 94.0%\n", + "Test Loss: 0.16679\n", + "Test Accuracy: 50.0%\n", + "Epoch 49\n", + "Train Loss: 0.19644\n", + "Train Accuracy: 95.0%\n", + "Test Loss: 0.15150\n", + "Test Accuracy: 50.0%\n" + ] + } + ], + "source": [ + "\n", + "from braindecode.torch_ext.util import np_to_var, var_to_np\n", + "from braindecode.datautil.iterators import get_balanced_batches\n", + "import torch.nn.functional as F\n", + "from numpy.random import RandomState\n", + "rng = RandomState(None)\n", + "#rng = RandomState((2017,6,30))\n", + "for i_epoch in range(50):\n", + " i_trials_in_batch = get_balanced_batches(len(train_set.X), rng, shuffle=True,\n", + " batch_size=32)\n", + " # Set model to training mode\n", + " model.train()\n", + " for i_trials in i_trials_in_batch:\n", + " # Have to add empty fourth dimension to X\n", + " batch_X = train_set.X[i_trials][:,:,:,None]\n", + " batch_y = train_set.y[i_trials]\n", + " net_in = np_to_var(batch_X)\n", + " if cuda:\n", + " net_in = net_in.cuda()\n", + " net_target = np_to_var(batch_y)\n", + " if cuda:\n", + " net_target = net_target.cuda()\n", + " # Remove gradients of last backward pass from all parameters\n", + " optimizer.zero_grad()\n", + " # Compute outputs of the network\n", + " outputs = model(net_in)\n", + " # Compute the loss\n", + " loss = F.nll_loss(outputs, net_target)\n", + " # Do the backpropagation\n", + " loss.backward()\n", + " # Update parameters with the optimizer\n", + " optimizer.step()\n", + "\n", + " # Print some statistics each epoch\n", + " model.eval()\n", + " print(\"Epoch {:d}\".format(i_epoch))\n", + " for setname, dataset in (('Train', train_set), ('Test', test_set)):\n", + " i_trials_in_batch = get_balanced_batches(len(dataset.X), rng, batch_size=32, shuffle=False)\n", + " outputs = []\n", + " net_targets = []\n", + " for i_trials in i_trials_in_batch:\n", + " batch_X = train_set.X[i_trials][:,:,:,None]\n", + " batch_y = train_set.y[i_trials]\n", + " \n", + " net_in = np_to_var(batch_X)\n", + " if cuda:\n", + " net_in = net_in.cuda()\n", + " net_target = np_to_var(batch_y)\n", + " if cuda:\n", + " net_target = net_target.cuda()\n", + " net_target = var_to_np(net_target)\n", + " output = var_to_np(model(net_in))\n", + " outputs.append(output)\n", + " net_targets.append(net_target)\n", + " net_targets = np_to_var(np.concatenate(net_targets))\n", + " outputs = np_to_var(np.concatenate(outputs))\n", + " loss = F.nll_loss(outputs, net_targets)\n", + " print(\"{:6s} Loss: {:.5f}\".format(\n", + " setname, float(var_to_np(loss))))\n", + " predicted_labels = np.argmax(var_to_np(outputs), axis=1)\n", + " accuracy = np.mean(dataset.y == predicted_labels)\n", + " print(\"{:6s} Accuracy: {:.1f}%\".format(\n", + " setname, accuracy * 100))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Problem: RAM not big enough\n", + "# next session, manage batches through the hard drive\n", + "# add analytics on training performance\n", + "\n", + "# rough results\n", + "# Subject 1:--------------------------------------------\n", + "# Epoch 49\n", + "# Train Loss: 0.00253\n", + "# Train Accuracy: 100.0%\n", + "# Test Loss: 0.00272\n", + "# Test Accuracy: 60.0%\n", + "\n", + "\n", + "# Subject 2:--------------------------------------------\n", + "# Epoch 49\n", + "# Train Loss: 0.00132\n", + "# Train Accuracy: 100.0%\n", + "# Test Loss: 0.00145\n", + "# Test Accuracy: 45.0%\n", + "\n", + "\n", + "# Subject 3:--------------------------------------------\n", + "# Epoch 27\n", + "# Train Loss: 0.00212\n", + "# Train Accuracy: 100.0%\n", + "# Test Loss: 0.00209\n", + "# Test Accuracy: 43.3%\n", + "\n", + "\n", + "# Subject 4:--------------------------------------------\n", + "# Epoch 34\n", + "# Train Loss: 0.00524\n", + "# Train Accuracy: 100.0%\n", + "# Test Loss: 0.00559\n", + "# Test Accuracy: 46.7%\n", + "\n", + "# Subject 5:--------------------------------------------\n", + "# Epoch 33\n", + "# Train Loss: 0.01777\n", + "# Train Accuracy: 100.0%\n", + "# Test Loss: 0.00994\n", + "# Test Accuracy: 55.0%\n", + "\n", + "# Subject 6:\n", + "# Epoch 49\n", + "# Train Loss: 0.00556\n", + "# Train Accuracy: 100.0%\n", + "# Test Loss: 0.00560\n", + "# Test Accuracy: 56.7%\n", + "\n", + "# Subject 7:\n", + "# Epoch 49\n", + "# Train Loss: 0.00129\n", + "# Train Accuracy: 100.0%\n", + "# Test Loss: 0.00143\n", + "# Test Accuracy: 51.7%\n", + "\n", + "\n", + "# Subject 8:\n", + "# Epoch 49\n", + "# Train Loss: 0.19644\n", + "# Train Accuracy: 95.0%\n", + "# Test Loss: 0.15150\n", + "# Test Accuracy: 50.0%" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}