--- /dev/null
+++ b/test/unit_tests/test_eegneuralnet.py
@@ -0,0 +1,559 @@
+# Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
+#          Lukas Gemein <l.gemein@gmail.com>
+#
+# License: BSD-3
+import logging
+
+import mne
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+from scipy.special import softmax
+from sklearn.base import clone
+from skorch.callbacks import LRScheduler
+from skorch.utils import to_tensor
+from skorch.helper import SliceDataset
+from torch import optim
+from torch.nn.functional import nll_loss
+
+from braindecode import EEGClassifier, EEGRegressor
+from braindecode.datasets import BaseConcatDataset, WindowsDataset
+from braindecode.models.base import EEGModuleMixin
+# from braindecode.models.util import models_dict
+from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
+from braindecode.training import CroppedLoss
+from braindecode.eegneuralnet import _EEGNeuralNet
+
+
+class MockDataset(torch.utils.data.Dataset):
+    def __len__(self):
+        return 5
+
+    def __getitem__(self, item):
+        return torch.rand(3, 10), item % 4
+
+
+class MockModuleReturnMockedPreds(EEGModuleMixin, torch.nn.Module):
+    def __init__(
+        self,
+        preds,
+        n_outputs=None,
+        n_chans=None,
+        chs_info=None,
+        n_times=None,
+        input_window_seconds=None,
+        sfreq=None,
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        self.preds = to_tensor(preds, device="cpu")
+        self.final_layer = torch.nn.Conv1d(self.n_chans, self.n_outputs, self.n_times)
+
+    def forward(self, x):
+        return self.preds
+
+
+class MockModuleFinalLayer(MockModuleReturnMockedPreds):
+    def forward(self, x):
+        return self.final_layer(x).reshape(x.shape[0], self.n_outputs)
+
+
+@pytest.fixture(params=[EEGClassifier, EEGRegressor])
+def eegneuralnet_cls(request):
+    return request.param
+
+
+@pytest.fixture
+def preds():
+    return np.array(
+        [
+            [[0.2, 0.1, 0.1, 0.1], [0.8, 0.9, 0.9, 0.9]],
+            [[0.2, 0.1, 0.1, 0.1], [0.8, 0.9, 0.9, 0.9]],
+            [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
+            [[1.0, 1.0, 1.0, 0.2], [0.0, 0.0, 0.0, 0.8]],
+            [[0.9, 0.8, 0.9, 1.0], [0.1, 0.2, 0.1, 0.0]],
+        ]
+    )
+
+
+@pytest.fixture
+def Xy():
+    dataset = MockDataset()
+    X, y = zip(*[dataset[i] for i in range(len(dataset))])
+    return np.stack(X), np.stack(y)
+
+
+@pytest.fixture
+def epochs(Xy):
+    X, y = Xy
+    metadata = [(yi, 0, 0, 9) for yi in y]
+    metadata = pd.DataFrame(
+        metadata,
+        columns=["target", "i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"],
+    )
+    return mne.EpochsArray(
+        X,
+        info=mne.create_info(
+            ch_names=[
+                "ch1",
+                "ch2",
+                "ch3",
+            ],
+            sfreq=10,
+            ch_types="eeg",
+        ),
+        metadata=metadata,
+    )
+
+
+@pytest.fixture
+def windows_dataset_metadata(epochs):
+    return WindowsDataset(
+        windows=epochs,
+        targets_from="metadata",
+        description={},
+    )
+
+
+@pytest.fixture
+def windows_dataset_channels(epochs):
+    return WindowsDataset(
+        windows=epochs,
+        targets_from="channels",
+        description={},
+    )
+
+
+@pytest.fixture
+def slice_dataset(windows_dataset_channels):
+    X = SliceDataset(windows_dataset_channels)
+    return X
+
+
+@pytest.fixture
+def concat_dataset_metadata(windows_dataset_metadata):
+    return BaseConcatDataset([windows_dataset_metadata, windows_dataset_metadata])
+
+
+@pytest.fixture
+def concat_dataset_channels(
+    windows_dataset_metadata,
+    windows_dataset_channels,
+):
+    return BaseConcatDataset([windows_dataset_metadata, windows_dataset_channels])
+
+
+def 
test_trialwise_predict_and_predict_proba(eegneuralnet_cls): + preds = np.array( + [ + [0.125, 0.875], + [1.0, 0.0], + [0.8, 0.2], + [0.8, 0.2], + [0.9, 0.1], + ] + ) + eegneuralnet = eegneuralnet_cls( + MockModuleReturnMockedPreds, + module__preds=preds, + module__n_outputs=2, + module__n_chans=3, + module__n_times=10, + optimizer=optim.Adam, + batch_size=32, + ) + eegneuralnet.initialize() + target_predict = preds if isinstance(eegneuralnet, EEGRegressor) else preds.argmax(1) + preds = preds if isinstance(eegneuralnet, EEGRegressor) else softmax(preds, axis=1) + np.testing.assert_array_equal(target_predict, eegneuralnet.predict(MockDataset())) + np.testing.assert_allclose(preds, eegneuralnet.predict_proba(MockDataset())) + + +def test_cropped_predict_and_predict_proba(eegneuralnet_cls, preds): + eegneuralnet = eegneuralnet_cls( + MockModuleReturnMockedPreds, + module__preds=preds, + module__n_outputs=4, + module__n_chans=3, + module__n_times=3, + cropped=True, + criterion=CroppedLoss, + criterion__loss_function=nll_loss, + optimizer=optim.Adam, + batch_size=32, + ) + eegneuralnet.initialize() + target_predict = ( + preds.mean(-1) + if isinstance(eegneuralnet, EEGRegressor) + else preds.mean(-1).argmax(1) + ) + # for cropped decoding classifier returns one label for each trial (averaged over all crops) + np.testing.assert_array_equal(target_predict, eegneuralnet.predict(MockDataset())) + # for cropped decoding classifier returns values for each trial (average over all crops) + np.testing.assert_array_equal( + preds.mean(-1), eegneuralnet.predict_proba(MockDataset()) + ) + + +def test_cropped_predict_and_predict_proba_not_aggregate_predictions( + eegneuralnet_cls, preds +): + eegneuralnet = eegneuralnet_cls( + MockModuleReturnMockedPreds, + module__preds=preds, + module__n_outputs=4, + module__n_chans=3, + module__n_times=3, + cropped=True, + criterion=CroppedLoss, + criterion__loss_function=nll_loss, + optimizer=optim.Adam, + batch_size=32, + aggregate_predictions=False, + ) + eegneuralnet.initialize() + target_predict = preds if isinstance(eegneuralnet, EEGRegressor) else preds.argmax(1) + np.testing.assert_array_equal(target_predict, eegneuralnet.predict(MockDataset())) + np.testing.assert_array_equal(preds, eegneuralnet.predict_proba(MockDataset())) + + +def test_predict_trials(eegneuralnet_cls, preds): + eegneuralnet = eegneuralnet_cls( + MockModuleReturnMockedPreds, + module__preds=preds, + module__n_outputs=4, + module__n_chans=3, + module__n_times=3, + cropped=False, + criterion=CroppedLoss, + criterion__loss_function=nll_loss, + optimizer=optim.Adam, + batch_size=32, + ) + eegneuralnet.initialize() + with pytest.warns( + UserWarning, + match="This method was designed to predict " "trials in cropped mode.", + ): + eegneuralnet.predict_trials(MockDataset(), return_targets=False) + + +def test_clonable(eegneuralnet_cls, preds): + eegneuralnet = eegneuralnet_cls( + MockModuleReturnMockedPreds, + module__preds=preds, + module__n_outputs=4, + module__n_chans=3, + module__n_times=3, + cropped=False, + callbacks=[ + "accuracy", + ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=1)), + ], + criterion=CroppedLoss, + criterion__loss_function=nll_loss, + optimizer=optim.Adam, + batch_size=32, + ) + clone(eegneuralnet) + eegneuralnet.initialize() + clone(eegneuralnet) + + +def test_set_signal_params_numpy(eegneuralnet_cls, preds, Xy): + X, y = Xy + net = eegneuralnet_cls( + MockModuleFinalLayer, + module__preds=preds, + cropped=False, + optimizer=optim.Adam, + batch_size=32, + 
train_split=None, + max_epochs=1, + ) + net.fit(X, y=y) + assert net.module_.n_times == 10 + assert net.module_.n_chans == 3 + assert net.module_.n_outputs == (1 if isinstance(net, EEGRegressor) else 4) + + +def test_set_signal_params_epochs(eegneuralnet_cls, preds, epochs): + y = epochs.metadata.target.values + net = eegneuralnet_cls( + MockModuleFinalLayer, + module__preds=preds, + cropped=False, + optimizer=optim.Adam, + batch_size=32, + train_split=None, + max_epochs=1, + ) + net.fit(epochs, y=y) + assert net.module_.n_times == 10 + assert net.module_.n_chans == 3 + assert net.module_.n_outputs == (1 if isinstance(net, EEGRegressor) else 4) + assert net.module_.chs_info == epochs.info["chs"] + assert net.module_.input_window_seconds == 10 / 10 + assert net.module_.sfreq == 10 + + +def test_set_signal_params_torch_ds(eegneuralnet_cls, preds): + n_outputs = 1 if eegneuralnet_cls == EEGRegressor else 4 + net = eegneuralnet_cls( + MockModuleFinalLayer, + module__preds=preds, + module__n_outputs=n_outputs, + cropped=False, + optimizer=optim.Adam, + batch_size=32, + train_split=None, + max_epochs=1, + ) + net.fit(MockDataset(), y=None) + assert net.module_.n_times == 10 + assert net.module_.n_chans == 3 + assert net.module_.n_outputs == n_outputs + + +def test_set_signal_params_windows_ds_metadata( + eegneuralnet_cls, preds, windows_dataset_metadata +): + n_outputs = 1 if eegneuralnet_cls == EEGRegressor else 4 + net = eegneuralnet_cls( + MockModuleFinalLayer, + module__preds=preds, + cropped=False, + optimizer=optim.Adam, + batch_size=32, + train_split=None, + max_epochs=1, + ) + net.fit(windows_dataset_metadata, y=None) + assert net.module_.n_times == 10 + assert net.module_.n_chans == 3 + assert net.module_.n_outputs == n_outputs + + +def test_set_signal_params_windows_ds_channels( + eegneuralnet_cls, preds, windows_dataset_channels +): + n_outputs = 1 if eegneuralnet_cls == EEGRegressor else 4 + net = eegneuralnet_cls( + MockModuleFinalLayer, + module__preds=preds, + module__n_outputs=n_outputs, + cropped=False, + optimizer=optim.Adam, + batch_size=32, + train_split=None, + max_epochs=1, + ) + net.fit(windows_dataset_channels, y=None) + assert net.module_.n_times == 10 + assert net.module_.n_chans == 3 + assert net.module_.n_outputs == n_outputs + + +def test_set_signal_params_concat_ds_metadata( + eegneuralnet_cls, preds, concat_dataset_metadata +): + n_outputs = 1 if eegneuralnet_cls == EEGRegressor else 4 + net = eegneuralnet_cls( + MockModuleFinalLayer, + module__preds=preds, + cropped=False, + optimizer=optim.Adam, + batch_size=32, + train_split=None, + max_epochs=1, + ) + net.fit(concat_dataset_metadata, y=None) + assert net.module_.n_times == 10 + assert net.module_.n_chans == 3 + assert net.module_.n_outputs == n_outputs + + +def test_set_signal_params_concat_ds_channels( + eegneuralnet_cls, preds, concat_dataset_channels +): + n_outputs = 1 if eegneuralnet_cls == EEGRegressor else 4 + net = eegneuralnet_cls( + MockModuleFinalLayer, + module__preds=preds, + module__n_outputs=n_outputs, + cropped=False, + optimizer=optim.Adam, + batch_size=32, + train_split=None, + max_epochs=1, + ) + net.fit(concat_dataset_channels, y=None) + assert net.module_.n_times == 10 + assert net.module_.n_chans == 3 + assert net.module_.n_outputs == n_outputs + + +def test_initialized_module(eegneuralnet_cls, preds, caplog, Xy): + X, y = Xy + module = MockModuleReturnMockedPreds( + preds=preds, + n_outputs=12, + n_chans=12, + n_times=12, + ) + net = eegneuralnet_cls( + module, + cropped=False, + 
max_epochs=1,
+        train_split=None,
+    )
+    with caplog.at_level(logging.INFO):
+        net.fit(X, y)
+    assert "The module passed is already initialized" in caplog.text
+    assert net.module_.n_outputs == 12
+    assert net.module_.n_chans == 12
+    assert net.module_.n_times == 12
+
+
+# @pytest.mark.parametrize("model_name,model_cls", models_dict.items())
+def test_module_name(eegneuralnet_cls):
+    net = eegneuralnet_cls(
+        "ShallowFBCSPNet",
+        module__n_outputs=4,
+        module__n_chans=3,
+        module__n_times=100,
+        cropped=False,
+    )
+    net.initialize()
+    assert isinstance(net.module_, ShallowFBCSPNet)
+
+
+def test_unknown_module_name(eegneuralnet_cls):
+    net = eegneuralnet_cls(
+        "InexistentModel",
+    )
+    with pytest.raises(ValueError) as excinfo:
+        net.initialize()
+    assert "Unknown model name" in str(excinfo.value)
+
+
+def test_EEGRegressor_drop_index(Xy, preds):
+    # Initialize an EEGRegressor; drop_index is exercised via get_iterator below
+    X, y = Xy
+
+    net = EEGRegressor(
+        MockModuleFinalLayer,
+        module__preds=preds,
+        cropped=False,
+        optimizer=optim.Adam,
+        batch_size=32,
+        train_split=None,
+        max_epochs=1,
+    )
+
+    # With drop_index=False the underlying DataLoader should be returned
+    iterator = net.get_iterator(X, training=False, drop_index=False)
+    assert isinstance(iterator, torch.utils.data.DataLoader)
+
+
+def test_EEGRegressor_get_n_outputs(preds):
+    # Initialize EEGRegressor
+    eeg_regressor = EEGRegressor(
+        MockModuleFinalLayer,
+        module__preds=preds,
+        cropped=False,
+        optimizer=optim.Adam,
+        batch_size=2,
+        train_split=None,
+        max_epochs=1,
+    )
+
+    # Test the _get_n_outputs method
+    assert eeg_regressor._get_n_outputs(y=None, classes=None) is None
+    assert eeg_regressor._get_n_outputs(
+        y=np.array([0, 1, 2, 3, 4]), classes=None
+    ) == 1
+    assert eeg_regressor._get_n_outputs(
+        y=np.array([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), classes=None
+    ) == 5
+
+
+def test_EEGRegressor_predict_trials(Xy, preds):
+    X, y = Xy
+    # Initialize EEGRegressor
+    eeg_regressor = EEGRegressor(
+        MockModuleFinalLayer,
+        module__preds=preds,
+        cropped=False,
+        optimizer=optim.Adam,
+        batch_size=2,
+        train_split=None,
+        max_epochs=1,
+    )
+
+    eeg_regressor.fit(X, y=y)
+
+    # local name kept distinct from the ``preds`` fixture
+    trial_preds, targets = eeg_regressor.predict_trials(X, return_targets=True)
+    assert trial_preds.shape[0] == len(X)
+    assert np.array_equal(
+        targets, np.concatenate([X[i][1] for i in range(len(X))])
+    )
+
+
+from braindecode.eegneuralnet import CroppedTrialEpochScoring
+
+
+class ConcreteEEGNeuralNet(_EEGNeuralNet):
+    def _get_n_outputs(self, y, classes):
+        # minimal stub: the tests below never need to infer n_outputs
+        pass
+
+
+@pytest.fixture
+def net():
+    net = ConcreteEEGNeuralNet(
+        module="EEGNetv4",
+        criterion=CroppedTrialEpochScoring,
+        cropped=False,
+        max_epochs=1,
+        train_split=None,
+        n_times=5,
+    )
+    return net
+
+
+def test_cropped_trial_epoch_scoring(net):
+    callbacks = net._parse_str_callback('accuracy')
+    train_scoring = callbacks[0][1]
+    valid_scoring = callbacks[1][1]
+
+    assert train_scoring.on_train is True
+    assert train_scoring.name == 'train_accuracy'
+
+    assert valid_scoring.on_train is False
+    assert valid_scoring.name == 'valid_accuracy'
+
+
+def test_get_n_outputs():
+    with pytest.raises(TypeError):
+        _EEGNeuralNet()._get_n_outputs(None, None)
+
+
+def test_set_signal_params_slice_dataset(
+    eegneuralnet_cls, preds, slice_dataset
+):
+    # EEGClassifier should infer one output per class, EEGRegressor a single output
+    n_outputs = 5 if eegneuralnet_cls is EEGClassifier else 1
+    y_train = np.array([0, 1, 2, 3, 4])  # dummy targets with five distinct labels
+
+    net = eegneuralnet_cls(
+        MockModuleFinalLayer,
+        module__preds=preds,
+        cropped=False,
+        optimizer=optim.Adam,
+ batch_size=32, + train_split=None, + max_epochs=1, + ) + net.fit(slice_dataset, y=y_train) + assert net.module_.n_times == 10 + assert net.module_.n_chans == 3 + assert net.module_.n_outputs == n_outputs
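+
+
+# A minimal, hypothetical extra check (a sketch, not part of the original suite;
+# the test name is ours): it exercises only the mock module defined above, so it
+# relies solely on shapes already used elsewhere in this file.
+def test_mock_module_final_layer_output_shape(preds):
+    module = MockModuleFinalLayer(
+        preds=preds,
+        n_outputs=4,
+        n_chans=3,
+        n_times=10,
+    )
+    # Conv1d(3, 4, kernel_size=10) on input (batch, 3, 10) gives (batch, 4, 1),
+    # which forward() reshapes to (batch, n_outputs).
+    out = module(torch.rand(5, 3, 10))
+    assert out.shape == (5, 4)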