# Authors: Maciej Sliwowski
# Robin Tibor Schirrmeister
#
# License: BSD-3

import sys

import mne
import numpy as np
import pytest
import torch
from mne.io import concatenate_raws
from skorch.helper import predefined_split
from torch import optim

from braindecode import EEGClassifier
from braindecode.datasets.xy import create_from_X_y
from braindecode.models import ShallowFBCSPNet
from braindecode.training.losses import CroppedLoss
from braindecode.util import set_random_seeds


@pytest.mark.skipif(sys.version_info[:2] != (3, 7), reason="Only for Python 3.7")
def test_cropped_decoding():
    # 5, 6, 9, 10, 13, 14 are codes for executed and imagined hands/feet
subject_id = 1
event_codes = [5, 6, 9, 10, 13, 14]
# This will download the files if you don't have them yet,
# and then return the paths to the files.
physionet_paths = mne.datasets.eegbci.load_data(
subject_id, event_codes, update_path=False
)
# Load each of the files
parts = [
mne.io.read_raw_edf(path, preload=True, stim_channel="auto", verbose="WARNING")
for path in physionet_paths
]
# Concatenate them
raw = concatenate_raws(parts)
# Find the events in this dataset
events, _ = mne.events_from_annotations(raw)
# Use only EEG channels
eeg_channel_inds = mne.pick_types(
raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads"
)
# Extract trials, only using EEG channels
epoched = mne.Epochs(
raw,
events,
dict(hands=2, feet=3),
tmin=1,
tmax=4.1,
proj=False,
picks=eeg_channel_inds,
baseline=None,
preload=True,
)
    # Convert data from volt to microvolt
# Pytorch expects float32 for input and int64 for labels.
X = (epoched.get_data() * 1e6).astype(np.float32)
y = (epoched.events[:, 2] - 2).astype(np.int64) # 2,3 -> 0,1
# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = False
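    # Seed the python, numpy and torch RNGs so the results below are reproducible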
set_random_seeds(seed=20170629, cuda=cuda)
# This will determine how many crops are processed in parallel
input_window_samples = 450
n_classes = 2
in_chans = X.shape[1]
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(
in_chans=in_chans,
n_classes=n_classes,
input_window_samples=input_window_samples,
final_conv_length=12,
)
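    # Convert the network into a dense prediction model that outputs one
    # prediction per input sample, as needed for cropped decoding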
model.to_dense_prediction_model()
if cuda:
model.cuda()
# Perform forward pass to determine how many outputs per input
n_preds_per_input = model.get_output_shape()[2]
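    # Create windowed datasets from the numpy arrays; using the number of
    # predictions per window as the stride keeps the crops' output regions
    # non-overlapping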
train_set = create_from_X_y(
X[:60],
y[:60],
drop_last_window=False,
sfreq=100,
window_size_samples=input_window_samples,
window_stride_samples=n_preds_per_input,
)
valid_set = create_from_X_y(
X[60:],
y[60:],
drop_last_window=False,
sfreq=100,
window_size_samples=input_window_samples,
window_stride_samples=n_preds_per_input,
)
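    # Let skorch validate on the held-out set instead of an internal random split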
train_split = predefined_split(valid_set)
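    # In cropped mode, CroppedLoss averages the dense per-timestep predictions
    # over time before applying the negative log-likelihood loss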
clf = EEGClassifier(
model,
cropped=True,
criterion=CroppedLoss,
criterion__loss_function=torch.nn.functional.nll_loss,
optimizer=optim.Adam,
train_split=train_split,
batch_size=32,
callbacks=["accuracy"],
classes=[0, 1],
)
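    # y=None: the targets are already contained in the windowed datasets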
clf.fit(train_set, y=None, epochs=4)
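    # Regression check against the loss/accuracy values obtained with the fixed seed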
np.testing.assert_allclose(
clf.history[:, "train_loss"],
np.array([1.391054, 1.278387, 1.086732, 1.111006]),
rtol=1e-3,
atol=1e-4,
)
np.testing.assert_allclose(
clf.history[:, "valid_loss"],
np.array([2.24272, 0.891798, 0.741147, 0.933025]),
rtol=1e-3,
atol=1e-3,
)
np.testing.assert_allclose(
clf.history[:, "train_accuracy"],
np.array([0.5, 0.516667, 0.6, 0.533333]),
rtol=1e-3,
atol=1e-4,
)
np.testing.assert_allclose(
clf.history[:, "valid_accuracy"],
np.array([0.466667, 0.533333, 0.6, 0.6]),
rtol=1e-3,
atol=1e-4,
)