[ac138c]: / examples / advanced_training / plot_data_augmentation.py

Download this file

286 lines (232 with data), 9.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
"""
Data Augmentation on BCIC IV 2a Dataset
=======================================
This tutorial shows how to train EEG deep models with data augmentation. It
follows the trial-wise decoding example and also illustrates the effect of a
transform on the input signals.

.. contents:: This example covers:
   :local:
   :depth: 2

"""
# Authors: Simon Brandt <simonbrandt@protonmail.com>
# Cédric Rommel <cedric.rommel@inria.fr>
#
# License: BSD (3-clause)
######################################################################
# Loading and preprocessing the dataset
# -------------------------------------
#
# Loading
# ~~~~~~~
from skorch.helper import predefined_split
from skorch.callbacks import LRScheduler
from braindecode import EEGClassifier
from braindecode.datasets import MOABBDataset
# Fetch the recordings of a single subject from the MOABB dataset
# registered under the name "BNCI2014001".
subject_id = 3
dataset = MOABBDataset(
    dataset_name="BNCI2014001",
    subject_ids=[subject_id],
)
######################################################################
# Preprocessing
# ~~~~~~~~~~~~~
#
from braindecode.preprocessing import (
exponential_moving_standardize,
preprocess,
Preprocessor,
)
from numpy import multiply
# Band-pass edges, in Hz.
low_cut_hz = 4.0
high_cut_hz = 38.0
# Settings of the exponential moving standardization.
factor_new = 1e-3
init_block_size = 1000
# Multiplicative factor converting V to uV.
factor = 1e6

# Each step is built on its own, then collected in the order used so far:
# channel selection, unit conversion, band-pass filter, standardization.
keep_eeg_only = Preprocessor("pick_types", eeg=True, meg=False, stim=False)
volts_to_microvolts = Preprocessor(lambda data: multiply(data, factor))
bandpass_filter = Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz)
moving_standardization = Preprocessor(
    exponential_moving_standardize,
    factor_new=factor_new,
    init_block_size=init_block_size,
)
preprocessors = [
    keep_eeg_only,
    volts_to_microvolts,
    bandpass_filter,
    moving_standardization,
]

# n_jobs=-1 presumably requests all available cores -- confirm against the
# braindecode ``preprocess`` documentation.
preprocess(dataset, preprocessors, n_jobs=-1)
######################################################################
# Extracting windows
# ~~~~~~~~~~~~~~~~~~
#
from braindecode.preprocessing import create_windows_from_events
# Offset of the window start relative to the trial start, in seconds.
trial_start_offset_seconds = -0.5

# Read the sampling frequency from the first recording and make sure every
# other recording uses the same one.
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all(ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets)

# Express the start offset as a number of samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Let braindecode cut the recordings into windows around the events, using
# the start offset computed above and no offset at the trial stop.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    preload=True,
)
######################################################################
# Split dataset into train and valid
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
splitted = windows_dataset.split("session")
# "0train" is the training session, "1test" the evaluation session.
train_set, valid_set = splitted["0train"], splitted["1test"]
######################################################################
# Defining a Transform
# --------------------
#
# Data can be manipulated by transforms, which are callable objects. A
# transform is usually handled by a custom data loader, but can also be called
# directly on input data, as demonstrated below for illustrative purposes.
#
# First, we need to define a Transform. Here we chose the FrequencyShift, which
# randomly translates all frequencies within a given range.
from braindecode.augmentation import FrequencyShift
# ``probability`` is the probability of actually modifying the input (1.0
# here); ``max_delta_freq=2.0`` means the frequency shifts are sampled
# between -2 and 2 Hz.
transform = FrequencyShift(probability=1.0, sfreq=sfreq, max_delta_freq=2.0)
######################################################################
# Manipulating one session and visualizing the transformed data
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
#
# Next, let us augment one session to show the resulting frequency shift. Its
# windows are stacked into a single array so that MNE functions can be applied
# to them.
import torch
import numpy as np
# Stack the first item of every (X, y, i) tuple yielded by the first training
# dataset into one array; the other two items are not needed here.
X = np.stack([window for window, _, _ in train_set.datasets[0]])
# Calling the underlying operation directly lets us impose a fixed shift
# (10 Hz) for visualization, instead of a shift sampled randomly between
# -2 and 2 Hz.
X_tr, _ = transform.operation(  # type: ignore[has-type]
    torch.as_tensor(X).float(), None, 10.0, sfreq
)
######################################################################
# The psd of the transformed session has now been shifted by 10 Hz, as one can
# see on the psd plot.
import mne
import matplotlib.pyplot as plt
def plot_psd(data, axis, label, color):
    """Plot the averaged multitaper PSD of ``data``, in dB, on ``axis``.

    The spectrum is estimated between 0.1 and 100 Hz with the module-level
    sampling frequency ``sfreq``, converted to decibels and then averaged
    over the two leading dimensions of the estimate.

    Parameters
    ----------
    data : array
        Signals passed to ``mne.time_frequency.psd_array_multitaper``.
        Presumably shaped (n_windows, n_channels, n_times) -- confirm
        against the caller.
    axis : matplotlib axes
        Axes the curve is drawn on.
    label : str
        Legend label of the curve.
    color : str
        Colour of the curve.
    """
    power, frequencies = mne.time_frequency.psd_array_multitaper(
        data, sfreq=sfreq, fmin=0.1, fmax=100
    )
    power_db = 10.0 * np.log10(power)
    # Two successive means over axis 0 collapse the two leading dimensions,
    # leaving one value per frequency.
    mean_power_db = power_db.mean(0).mean(0)
    axis.plot(frequencies, mean_power_db, color=color, label=label)
# Draw both spectra on the same axes so that the shift is directly visible:
# the original signals in black, the frequency-shifted ones in red.
_, ax = plt.subplots()
plot_psd(X, ax, "original", "k")
plot_psd(X_tr.numpy(), ax, "shifted", "r")
ax.set(
    # The preprocessing keeps EEG channels only (no MEG sensors), so the
    # figure is labelled as EEG rather than gradiometers.
    title="Multitaper PSD (EEG)",
    xlabel="Frequency (Hz)",
    ylabel="Power Spectral Density (dB)",
)
ax.legend()
plt.show()
######################################################################
# Training a model with data augmentation
# ---------------------------------------
#
# Now that we know how to instantiate ``Transforms``, it is time to learn how
# to use them to train a model and try to improve its generalization power.
# Let's first create a model.
#
# Create model
# ~~~~~~~~~~~~
#
# The model to be trained is defined as usual.
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet
# Train on the GPU whenever one is available, otherwise on the CPU.
cuda = torch.cuda.is_available()
device = "cpu"
if cuda:
    device = "cuda"
    torch.backends.cudnn.benchmark = True

# Fix the random seed so that results are roughly reproducible.
# With cudnn benchmark enabled, GPU non-determinism may still make results
# differ substantially between runs. For more consistent results, at the
# cost of increased computation time, set `cudnn_benchmark=False` in
# `set_random_seeds` or remove `torch.backends.cudnn.benchmark = True`.
seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
classes = list(range(n_classes))

# The first two dimensions of the first training input give the number of
# channels and the number of time steps.
n_channels, n_times = train_set[0][0].shape[:2]

model = ShallowFBCSPNet(
    n_chans=n_channels,
    n_outputs=n_classes,
    n_times=n_times,
    final_conv_length="auto",
)
######################################################################
# Create an EEGClassifier with the desired augmentation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In order to train with data augmentation, a custom data loader can be used
# for the training. Multiple transforms can be passed to it and will be applied
# sequentially to the batched data within the ``AugmentedDataLoader`` object.
from braindecode.augmentation import AugmentedDataLoader, SignFlip
# Frequency shift sampled between -2 and 2 Hz, applied with probability 0.5.
freq_shift = FrequencyShift(
    probability=0.5,
    sfreq=sfreq,
    max_delta_freq=2.0,
)
# Sign flip, applied with probability 0.1.
sign_flip = SignFlip(probability=0.1)

# Both augmentations, in the order they are handed to the data loader.
transforms = [freq_shift, sign_flip]

# Move the model to the GPU when one is available.
if cuda:
    model.cuda()
######################################################################
# The model is now trained as in the trial-wise example. The
# ``AugmentedDataLoader`` is used as the train iterator and the list of
# transforms is passed as an argument.
lr = 0.0625 * 0.01
weight_decay = 0
batch_size = 64
n_epochs = 4

# Learning-rate schedule registered as a callback under "lr_scheduler".
lr_scheduler = LRScheduler("CosineAnnealingLR", T_max=n_epochs - 1)

clf = EEGClassifier(
    model,
    # Data loading: the custom loader applies the augmentations.
    iterator_train=AugmentedDataLoader,
    iterator_train__transforms=transforms,
    batch_size=batch_size,
    # valid_set is used for validation instead of an automatic split.
    train_split=predefined_split(valid_set),
    # Loss and optimizer.
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    callbacks=[
        "accuracy",
        ("lr_scheduler", lr_scheduler),
    ],
    device=device,
    classes=classes,
)

# Train for the requested number of epochs. `y` is None because the targets
# are already supplied by the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)
######################################################################
# Manually composing Transforms
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# It would be equivalent (although more verbose) to pass to ``EEGClassifier`` a
# composition of the same transforms:
from braindecode.augmentation import Compose
# Wrap the same list of transforms into a single ``Compose`` object.
composed_transforms = Compose(transforms=transforms)
######################################################################
# Setting the data augmentation at the Dataset level
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Note that most transforms can also be passed directly to the
# ``WindowsDataset`` object through its ``transform`` argument, as is commonly
# done in other libraries. However, it is advised to use the
# ``AugmentedDataLoader`` as above, as it is compatible with all transforms and
# can be more efficient.
train_set.transform = composed_transforms  # set the composed transforms as the training set's ``transform``