--- a +++ b/brainDecode/tutorial/TrialWise_Example.ipynb @@ -0,0 +1,421 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using default location ~/mne_data for EEGBCI...\n", + "Downloading http://www.physionet.org/physiobank/database/eegmmidb/S001/S001R05.edf (2.5 MB)\n", + "[........................................] 100.00000 \\ ( 2.5 MB / 2.5 MB, 2.3 MB/s) \n", + "Do you want to set the path:\n", + " /home/fred/mne_data\n", + "as the default EEGBCI dataset path in the mne-python config [y]/n? y\n", + "Downloading http://www.physionet.org/physiobank/database/eegmmidb/S001/S001R06.edf (2.5 MB)\n", + "[........................................] 100.00000 / ( 2.5 MB / 2.5 MB, 15.0 MB/s) \n", + "Downloading http://www.physionet.org/physiobank/database/eegmmidb/S001/S001R09.edf (2.5 MB)\n", + "[........................................] 100.00000 \\ ( 2.5 MB / 2.5 MB, 30.2 MB/s) \n", + "Downloading http://www.physionet.org/physiobank/database/eegmmidb/S001/S001R10.edf (2.5 MB)\n", + "[........................................] 100.00000 - ( 2.5 MB / 2.5 MB, 18.7 MB/s) \n", + "Downloading http://www.physionet.org/physiobank/database/eegmmidb/S001/S001R13.edf (2.5 MB)\n", + "[........................................] 100.00000 \\ ( 2.5 MB / 2.5 MB, 25.3 MB/s) \n", + "Downloading http://www.physionet.org/physiobank/database/eegmmidb/S001/S001R14.edf (2.5 MB)\n", + "[........................................] 100.00000 | ( 2.5 MB / 2.5 MB, 25.5 MB/s) \n", + "Removing orphaned offset at the beginning of the file.\n", + "179 events found\n", + "Events id: [1 2 3]\n", + "90 matching events found\n", + "Loading data for 90 events and 497 original time points ...\n", + "0 bad epochs dropped\n" + ] + } + ], + "source": [ + "\n", + "import mne\n", + "from mne.io import concatenate_raws\n", + "\n", + "# 5,6,7,10,13,14 are codes for executed and imagined hands/feet\n", + "subject_id = 1\n", + "event_codes = [5,6,9,10,13,14]\n", + "\n", + "# This will download the files if you don't have them yet,\n", + "# and then return the paths to the files.\n", + "physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)\n", + "\n", + "# Load each of the files\n", + "parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto', verbose='WARNING')\n", + " for path in physionet_paths]\n", + "\n", + "# Concatenate them\n", + "raw = concatenate_raws(parts)\n", + "\n", + "# Find the events in this dataset\n", + "events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')\n", + "\n", + "# Use only EEG channels\n", + "eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n", + " exclude='bads')\n", + "\n", + "# Extract trials, only using EEG channels\n", + "epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,baseline=None, preload=True)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "# Convert data from volt to millivolt\n", + "# Pytorch expects float32 for input and int64 for labels.\n", + "X = (epoched.get_data() * 1e6).astype(np.float32)\n", + "y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from braindecode.datautil.signal_target import SignalAndTarget\n", + "\n", + "train_set = SignalAndTarget(X[:60], y=y[:60])\n", + "test_set = SignalAndTarget(X[60:], y=y[60:])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n", + "from torch import nn\n", + "from braindecode.torch_ext.util import set_random_seeds\n", + "\n", + "# Set if you want to use GPU\n", + "# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.\n", + "cuda = False\n", + "set_random_seeds(seed=20170629, cuda=cuda)\n", + "n_classes = 2\n", + "in_chans = train_set.X.shape[1]\n", + "# final_conv_length = auto ensures we only get a single output in the time dimension\n", + "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,\n", + " input_time_length=train_set.X.shape[2],\n", + " final_conv_length='auto').create_network()\n", + "if cuda:\n", + " model.cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from torch import optim\n", + "\n", + "optimizer = optim.Adam(model.parameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0\n", + "Train Loss: 1.17760\n", + "Train Accuracy: 51.7%\n", + "Test Loss: 1.26024\n", + "Test Accuracy: 53.3%\n", + "Epoch 1\n", + "Train Loss: 0.70688\n", + "Train Accuracy: 63.3%\n", + "Test Loss: 0.93679\n", + "Test Accuracy: 56.7%\n", + "Epoch 2\n", + "Train Loss: 0.39426\n", + "Train Accuracy: 86.7%\n", + "Test Loss: 0.65984\n", + "Test Accuracy: 66.7%\n", + "Epoch 3\n", + "Train Loss: 0.33593\n", + "Train Accuracy: 90.0%\n", + "Test Loss: 0.65626\n", + "Test Accuracy: 63.3%\n", + "Epoch 4\n", + "Train Loss: 0.27905\n", + "Train Accuracy: 96.7%\n", + "Test Loss: 0.61044\n", + "Test Accuracy: 56.7%\n", + "Epoch 5\n", + "Train Loss: 0.27319\n", + "Train Accuracy: 96.7%\n", + "Test Loss: 0.59499\n", + "Test Accuracy: 66.7%\n", + "Epoch 6\n", + "Train Loss: 0.26624\n", + "Train Accuracy: 88.3%\n", + "Test Loss: 0.61360\n", + "Test Accuracy: 70.0%\n", + "Epoch 7\n", + "Train Loss: 0.23338\n", + "Train Accuracy: 91.7%\n", + "Test Loss: 0.65626\n", + "Test Accuracy: 73.3%\n", + "Epoch 8\n", + "Train Loss: 0.21903\n", + "Train Accuracy: 90.0%\n", + "Test Loss: 0.68762\n", + "Test Accuracy: 63.3%\n", + "Epoch 9\n", + "Train Loss: 0.18902\n", + "Train Accuracy: 96.7%\n", + "Test Loss: 0.67174\n", + "Test Accuracy: 56.7%\n", + "Epoch 10\n", + "Train Loss: 0.17774\n", + "Train Accuracy: 96.7%\n", + "Test Loss: 0.68129\n", + "Test Accuracy: 56.7%\n", + "Epoch 11\n", + "Train Loss: 0.16396\n", + "Train Accuracy: 98.3%\n", + "Test Loss: 0.69190\n", + "Test Accuracy: 56.7%\n", + "Epoch 12\n", + "Train Loss: 0.15179\n", + "Train Accuracy: 98.3%\n", + "Test Loss: 0.69641\n", + "Test Accuracy: 56.7%\n", + "Epoch 13\n", + "Train Loss: 0.15035\n", + "Train Accuracy: 98.3%\n", + "Test Loss: 0.70514\n", + "Test Accuracy: 60.0%\n", + "Epoch 14\n", + "Train Loss: 0.14255\n", + "Train Accuracy: 98.3%\n", + "Test Loss: 0.74239\n", + "Test Accuracy: 53.3%\n", + "Epoch 15\n", + "Train Loss: 0.13597\n", + "Train Accuracy: 98.3%\n", + "Test Loss: 0.77918\n", + "Test Accuracy: 53.3%\n", + "Epoch 16\n", + "Train Loss: 0.13680\n", + "Train Accuracy: 98.3%\n", + "Test Loss: 0.84160\n", + "Test Accuracy: 56.7%\n", + "Epoch 17\n", + "Train Loss: 0.12191\n", + "Train Accuracy: 98.3%\n", + "Test Loss: 0.87803\n", + "Test Accuracy: 56.7%\n", + "Epoch 18\n", + "Train Loss: 0.10604\n", + "Train Accuracy: 98.3%\n", + "Test Loss: 0.90946\n", + "Test Accuracy: 56.7%\n", + "Epoch 19\n", + "Train Loss: 0.10445\n", + "Train Accuracy: 98.3%\n", + "Test Loss: 0.93717\n", + "Test Accuracy: 56.7%\n", + "Epoch 20\n", + "Train Loss: 0.10497\n", + "Train Accuracy: 98.3%\n", + "Test Loss: 0.95503\n", + "Test Accuracy: 60.0%\n", + "Epoch 21\n", + "Train Loss: 0.10946\n", + "Train Accuracy: 100.0%\n", + "Test Loss: 0.99562\n", + "Test Accuracy: 60.0%\n", + "Epoch 22\n", + "Train Loss: 0.05721\n", + "Train Accuracy: 100.0%\n", + "Test Loss: 0.84823\n", + "Test Accuracy: 60.0%\n", + "Epoch 23\n", + "Train Loss: 0.04464\n", + "Train Accuracy: 100.0%\n", + "Test Loss: 0.78784\n", + "Test Accuracy: 56.7%\n", + "Epoch 24\n", + "Train Loss: 0.04180\n", + "Train Accuracy: 100.0%\n", + "Test Loss: 0.78525\n", + "Test Accuracy: 56.7%\n", + "Epoch 25\n", + "Train Loss: 0.03730\n", + "Train Accuracy: 100.0%\n", + "Test Loss: 0.75642\n", + "Test Accuracy: 63.3%\n", + "Epoch 26\n", + "Train Loss: 0.03549\n", + "Train Accuracy: 100.0%\n", + "Test Loss: 0.71707\n", + "Test Accuracy: 66.7%\n", + "Epoch 27\n", + "Train Loss: 0.03416\n", + "Train Accuracy: 100.0%\n", + "Test Loss: 0.69796\n", + "Test Accuracy: 66.7%\n", + "Epoch 28\n", + "Train Loss: 0.03205\n", + "Train Accuracy: 100.0%\n", + "Test Loss: 0.69325\n", + "Test Accuracy: 66.7%\n", + "Epoch 29\n", + "Train Loss: 0.02805\n", + "Train Accuracy: 100.0%\n", + "Test Loss: 0.69243\n", + "Test Accuracy: 63.3%\n" + ] + } + ], + "source": [ + "from braindecode.torch_ext.util import np_to_var, var_to_np\n", + "from braindecode.datautil.iterators import get_balanced_batches\n", + "import torch.nn.functional as F\n", + "from numpy.random import RandomState\n", + "rng = RandomState()\n", + "for i_epoch in range(30):\n", + " i_trials_in_batch = get_balanced_batches(len(train_set.X), rng, shuffle=True,\n", + " batch_size=30)\n", + " # Set model to training mode\n", + " model.train()\n", + " for i_trials in i_trials_in_batch:\n", + " # Have to add empty fourth dimension to X\n", + " batch_X = train_set.X[i_trials][:,:,:,None]\n", + " batch_y = train_set.y[i_trials]\n", + " net_in = np_to_var(batch_X)\n", + " if cuda:\n", + " net_in = net_in.cuda()\n", + " net_target = np_to_var(batch_y)\n", + " if cuda:\n", + " net_target = net_target.cuda()\n", + " # Remove gradients of last backward pass from all parameters\n", + " optimizer.zero_grad()\n", + " # Compute outputs of the network\n", + " outputs = model(net_in)\n", + " # Compute the loss\n", + " loss = F.nll_loss(outputs, net_target)\n", + " # Do the backpropagation\n", + " loss.backward()\n", + " # Update parameters with the optimizer\n", + " optimizer.step()\n", + "\n", + " # Print some statistics each epoch\n", + " model.eval()\n", + " print(\"Epoch {:d}\".format(i_epoch))\n", + " for setname, dataset in (('Train', train_set), ('Test', test_set)):\n", + " # Here, we will use the entire dataset at once, which is still possible\n", + " # for such smaller datasets. Otherwise we would have to use batches.\n", + " net_in = np_to_var(dataset.X[:,:,:,None])\n", + " if cuda:\n", + " net_in = net_in.cuda()\n", + " net_target = np_to_var(dataset.y)\n", + " if cuda:\n", + " net_target = net_target.cuda()\n", + " outputs = model(net_in)\n", + " loss = F.nll_loss(outputs, net_target)\n", + " print(\"{:6s} Loss: {:.5f}\".format(\n", + " setname, float(var_to_np(loss))))\n", + " predicted_labels = np.argmax(var_to_np(outputs), axis=1)\n", + " accuracy = np.mean(dataset.y == predicted_labels)\n", + " print(\"{:6s} Accuracy: {:.1f}%\".format(\n", + " setname, accuracy * 100))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import mne\n", + "import numpy as np\n", + "from mne.io import concatenate_raws\n", + "from braindecode.datautil.signal_target import SignalAndTarget\n", + "\n", + "physionet_paths = [ mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(1,51)]\n", + "physionet_paths = np.concatenate(physionet_paths)\n", + "parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto')\n", + " for path in physionet_paths]\n", + "\n", + "raw = concatenate_raws(parts)\n", + "\n", + "picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n", + " exclude='bads')\n", + "\n", + "events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')\n", + "\n", + "# Read epochs (train will be done only between 1 and 2s)\n", + "# Testing will be done with a running classifier\n", + "epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=picks,\n", + " baseline=None, preload=True)\n", + "\n", + "physionet_paths_test = [mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(51,56)]\n", + "physionet_paths_test = np.concatenate(physionet_paths_test)\n", + "parts_test = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto')\n", + " for path in physionet_paths_test]\n", + "raw_test = concatenate_raws(parts_test)\n", + "\n", + "picks_test = mne.pick_types(raw_test.info, meg=False, eeg=True, stim=False, eog=False,\n", + " exclude='bads')\n", + "\n", + "events_test = mne.find_events(raw_test, shortest_event=0, stim_channel='STI 014')\n", + "\n", + "# Read epochs (train will be done only between 1 and 2s)\n", + "# Testing will be done with a running classifier\n", + "epoched_test = mne.Epochs(raw_test, events_test, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=picks_test,\n", + " baseline=None, preload=True)\n", + "\n", + "train_X = (epoched.get_data() * 1e6).astype(np.float32)\n", + "train_y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1\n", + "test_X = (epoched_test.get_data() * 1e6).astype(np.float32)\n", + "test_y = (epoched_test.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1\n", + "train_set = SignalAndTarget(train_X, y=train_y)\n", + "test_set = SignalAndTarget(test_X, y=test_y)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}