"""Preprocessors that work on Raw or Epochs objects."""
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
# Lukas Gemein <l.gemein@gmail.com>
# Simon Brandt <simonbrandt@protonmail.com>
# David Sabbagh <dav.sabbagh@gmail.com>
# Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)
from __future__ import annotations
from warnings import warn
from functools import partial
from collections.abc import Iterable
import sys
import platform
if sys.version_info < (3, 9):
from typing import Callable
else:
from collections.abc import Callable
import numpy as np
from numpy.typing import NDArray
import pandas as pd
from mne import create_info, BaseEpochs
from mne.io import BaseRaw
from joblib import Parallel, delayed
from braindecode.datasets.base import (
BaseConcatDataset,
BaseDataset,
WindowsDataset,
EEGWindowsDataset,
)
from braindecode.datautil.serialization import (
load_concat_dataset,
_check_save_dir_empty,
)


class Preprocessor(object):
    """Preprocessor for an MNE Raw or Epochs object.

    Applies the provided preprocessing function to the data of a Raw or
    Epochs object.

    If the function is provided as a string, the method with that name will
    be used (e.g., 'pick_channels', 'filter', etc.).
    If it is provided as a callable and `apply_on_array` is True, the
    `apply_function` method of Raw and Epochs objects will be used to apply
    the function on the internal arrays of Raw and Epochs.
    If `apply_on_array` is False, the callable must directly modify the Raw or
    Epochs object (e.g., by calling its method(s) or modifying its attributes).

    Parameters
    ----------
    fn : str or callable
        If str, the Raw/Epochs object must have a method with that name.
        If callable, directly apply the callable to the object.
    apply_on_array : bool
        Ignored if `fn` is not a callable. If True, the `apply_function`
        method of Raw and Epochs objects will be used to run `fn` on the
        underlying arrays directly. If False, `fn` must directly modify the
        Raw or Epochs object.
    kwargs :
        Keyword arguments to be forwarded to the MNE function.
    """

    def __init__(self, fn: Callable | str, *, apply_on_array: bool = True, **kwargs):
        if hasattr(fn, "__name__") and fn.__name__ == "<lambda>":
            warn("Preprocessing choices with lambda functions cannot be saved.")
        if callable(fn) and apply_on_array:
            channel_wise = kwargs.pop("channel_wise", False)
            picks = kwargs.pop("picks", None)
            n_jobs = kwargs.pop("n_jobs", 1)
            kwargs = dict(
                fun=partial(fn, **kwargs),
                channel_wise=channel_wise,
                picks=picks,
                n_jobs=n_jobs,
            )
            fn = "apply_function"
        self.fn = fn
        self.kwargs = kwargs

    def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
        try:
            self._try_apply(raw_or_epochs)
        except RuntimeError:
            # Maybe the function needs the data to be loaded and the data was
            # not loaded yet. Not all MNE functions need data to be loaded,
            # most importantly the 'crop' function can be lazily applied
            # without preloading data, which can make the overall
            # preprocessing pipeline substantially faster.
            raw_or_epochs.load_data()
            self._try_apply(raw_or_epochs)

    def _try_apply(self, raw_or_epochs):
        if callable(self.fn):
            self.fn(raw_or_epochs, **self.kwargs)
        else:
            if not hasattr(raw_or_epochs, self.fn):
                raise AttributeError(f"MNE object does not have a {self.fn} method.")
            getattr(raw_or_epochs, self.fn)(**self.kwargs)
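
# A minimal usage sketch for ``Preprocessor`` (illustrative only, not part of
# the module; assumes an MNE Raw object named ``raw`` is available):
#
#     def scale(x, factor):
#         return x * factor
#
#     preprocessors = [
#         Preprocessor("pick_types", eeg=True, stim=False),  # MNE method by name
#         Preprocessor("filter", l_freq=4.0, h_freq=30.0),   # MNE bandpass filter
#         Preprocessor(scale, factor=1e6),  # callable run on the data array
#     ]
#     for preproc in preprocessors:
#         preproc.apply(raw)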


def preprocess(
    concat_ds: BaseConcatDataset,
    preprocessors: list[Preprocessor],
    save_dir: str | None = None,
    overwrite: bool = False,
    n_jobs: int | None = None,
    offset: int = 0,
):
    """Apply preprocessors to a concat dataset.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        A concat of BaseDataset or WindowsDataset datasets to be preprocessed.
    preprocessors : list(Preprocessor)
        List of Preprocessor objects to apply to the dataset.
    save_dir : str | None
        If a string, the preprocessed data will be saved under the specified
        directory and the datasets in ``concat_ds`` will be reloaded with
        `preload=False`.
    overwrite : bool
        When `save_dir` is provided, controls whether to delete the old
        subdirectories that will be written to under `save_dir`. If False and
        the corresponding subdirectories already exist, a ``FileExistsError``
        will be raised.
    n_jobs : int | None
        Number of jobs for parallel execution. See `joblib.Parallel` for
        a more detailed explanation.
    offset : int
        If provided, the integer is added to the id of the dataset in the
        concat. This is useful in the setting of very large datasets, where
        one dataset has to be processed and saved at a time to account for
        its original position.

    Returns
    -------
    BaseConcatDataset :
        Preprocessed dataset.
    """
    # In case of serialization, make sure the directory is available before
    # preprocessing
    if save_dir is not None and not overwrite:
        _check_save_dir_empty(save_dir)

    if not isinstance(preprocessors, Iterable):
        raise ValueError("preprocessors must be a list of Preprocessor objects.")
    for elem in preprocessors:
        assert hasattr(elem, "apply"), "Preprocessor object needs an `apply` method."

    parallel_processing = (n_jobs is not None) and (n_jobs != 1)
    job_prefer = "threads" if platform.system() == "Windows" else None

    list_of_ds = Parallel(n_jobs=n_jobs, prefer=job_prefer)(
        delayed(_preprocess)(
            ds,
            i + offset,
            preprocessors,
            save_dir,
            overwrite,
            copy_data=(parallel_processing and (save_dir is None)),
        )
        for i, ds in enumerate(concat_ds.datasets)
    )

    if save_dir is not None:  # Reload datasets and replace in concat_ds
        ids_to_load = [i + offset for i in range(len(concat_ds.datasets))]
        concat_ds_reloaded = load_concat_dataset(
            save_dir,
            preload=False,
            target_name=None,
            ids_to_load=ids_to_load,
        )
        _replace_inplace(concat_ds, concat_ds_reloaded)
    else:
        if parallel_processing:  # joblib made copies
            _replace_inplace(concat_ds, BaseConcatDataset(list_of_ds))
        else:
            # joblib did not make copies, so the preprocessing happened
            # in-place. Recompute cumulative sizes, as transforms might have
            # changed them.
            concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)

    return concat_ds
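
# A minimal usage sketch for ``preprocess`` (illustrative; assumes
# ``concat_ds`` is a BaseConcatDataset, and ``./preproc`` is a hypothetical
# output directory):
#
#     preprocessors = [
#         Preprocessor("filter", l_freq=4.0, h_freq=30.0),
#         Preprocessor("resample", sfreq=100),
#     ]
#     preprocess(concat_ds, preprocessors, n_jobs=2)
#     # or, to serialize each preprocessed recording and reload it lazily:
#     preprocess(concat_ds, preprocessors, save_dir="./preproc", overwrite=True)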


def _replace_inplace(concat_ds, new_concat_ds):
    """Replace subdatasets and preproc_kwargs of a BaseConcatDataset inplace.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        Dataset to modify inplace.
    new_concat_ds : BaseConcatDataset
        Dataset to use to modify ``concat_ds``.
    """
    if len(concat_ds.datasets) != len(new_concat_ds.datasets):
        raise ValueError("Both inputs must have the same length.")
    for i in range(len(new_concat_ds.datasets)):
        concat_ds.datasets[i] = new_concat_ds.datasets[i]

    concat_kind = "raw" if hasattr(concat_ds.datasets[0], "raw") else "window"
    preproc_kwargs_attr = concat_kind + "_preproc_kwargs"
    if hasattr(new_concat_ds, preproc_kwargs_attr):
        setattr(
            concat_ds, preproc_kwargs_attr, getattr(new_concat_ds, preproc_kwargs_attr)
        )


def _preprocess(
    ds, ds_index, preprocessors, save_dir=None, overwrite=False, copy_data=False
):
    """Apply preprocessor(s) to Raw or Epochs object.

    Parameters
    ----------
    ds : BaseDataset | WindowsDataset
        Dataset object to preprocess.
    ds_index : int
        Index of the BaseDataset in its BaseConcatDataset. Ignored if save_dir
        is None.
    preprocessors : list(Preprocessor)
        List of preprocessors to apply to the dataset.
    save_dir : str | None
        If provided, save the preprocessed BaseDataset in the
        specified directory.
    overwrite : bool
        If True, overwrite existing file with the same name.
    copy_data : bool
        First copy the data in case it is preloaded. Necessary for parallel
        processing to work.
    """

    def _preprocess_raw_or_epochs(raw_or_epochs, preprocessors):
        # Copying the data is necessary in some scenarios for parallel
        # processing to work when the data is in memory (otherwise an error
        # about `_data` not being writeable is raised)
        if raw_or_epochs.preload and copy_data:
            raw_or_epochs._data = raw_or_epochs._data.copy()
        for preproc in preprocessors:
            preproc.apply(raw_or_epochs)

    if hasattr(ds, "raw"):
        if isinstance(ds, EEGWindowsDataset):
            warn(
                f"Applying preprocessors {preprocessors} to the mne.io.Raw "
                "of an EEGWindowsDataset."
            )
        _preprocess_raw_or_epochs(ds.raw, preprocessors)
    elif hasattr(ds, "windows"):
        _preprocess_raw_or_epochs(ds.windows, preprocessors)
    else:
        raise ValueError(
            "Can only preprocess concatenation of BaseDataset or "
            "WindowsDataset, with either a `raw` or `windows` attribute."
        )

    # Store preprocessing keyword arguments in the dataset
    _set_preproc_kwargs(ds, preprocessors)

    if save_dir is not None:
        concat_ds = BaseConcatDataset([ds])
        concat_ds.save(save_dir, overwrite=overwrite, offset=ds_index)
    else:
        return ds


def _get_preproc_kwargs(preprocessors):
    preproc_kwargs = []
    for p in preprocessors:
        # in case of an MNE method, fn is a str and kwargs is a dict
        func_name = p.fn
        func_kwargs = p.kwargs
        # in case of a custom callable with apply_on_array=False
        if callable(p.fn):
            func_name = p.fn.__name__
        # with apply_on_array=True, the callable was wrapped in a partial
        # stored under the 'fun' key of kwargs and fn was replaced by the
        # string 'apply_function'; recover the original function name and
        # keyword arguments from the partial
        else:
            if "fun" in p.kwargs:
                func_name = p.kwargs["fun"].func.__name__
                func_kwargs = p.kwargs["fun"].keywords
        preproc_kwargs.append((func_name, func_kwargs))
    return preproc_kwargs


def _set_preproc_kwargs(ds, preprocessors):
    """Record preprocessing keyword arguments in BaseDataset or WindowsDataset.

    Parameters
    ----------
    ds : BaseDataset | WindowsDataset
        Dataset in which to record preprocessing keyword arguments.
    preprocessors : list
        List of preprocessors.
    """
    preproc_kwargs = _get_preproc_kwargs(preprocessors)
    if isinstance(ds, WindowsDataset):
        kind = "window"
        if isinstance(ds, EEGWindowsDataset):
            kind = "raw"
    elif isinstance(ds, BaseDataset):
        kind = "raw"
    else:
        raise TypeError(
            f"ds must be a BaseDataset or a WindowsDataset, got {type(ds)}"
        )
    setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)


def exponential_moving_standardize(
    data: NDArray,
    factor_new: float = 0.001,
    init_block_size: int | None = None,
    eps: float = 1e-4,
):
    r"""Perform exponential moving standardization.

    Compute the exponential moving mean :math:`m_t` at time `t` as
    :math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Then, compute the exponential moving variance :math:`v_t` at time `t` as
    :math:`v_t=\mathrm{factornew} \cdot (m_t - x_t)^2 + (1 - \mathrm{factornew}) \cdot v_{t-1}`.

    Finally, standardize the data point :math:`x_t` at time `t` as:
    :math:`x'_t=(x_t - m_t) / max(\sqrt{v_t}, eps)`.

    Parameters
    ----------
    data : np.ndarray (n_channels, n_times)
    factor_new : float
    init_block_size : int
        Standardize data up to this index with regular standardization.
    eps : float
        Stabilizer for division by zero variance.

    Returns
    -------
    standardized : np.ndarray (n_channels, n_times)
        Standardized data.
    """
    data = data.T
    df = pd.DataFrame(data)
    meaned = df.ewm(alpha=factor_new).mean()
    demeaned = df - meaned
    squared = demeaned * demeaned
    square_ewmed = squared.ewm(alpha=factor_new).mean()
    standardized = demeaned / np.maximum(eps, np.sqrt(np.array(square_ewmed)))
    standardized = np.array(standardized)
    if init_block_size is not None:
        i_time_axis = 0
        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        init_std = np.std(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        init_block_standardized = (data[0:init_block_size] - init_mean) / np.maximum(
            eps, init_std
        )
        standardized[0:init_block_size] = init_block_standardized
    return standardized.T
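
# A minimal usage sketch (illustrative; the (n_channels, n_times) shape is
# the only assumption):
#
#     rng = np.random.default_rng(0)
#     x = rng.standard_normal((2, 1000))  # (n_channels, n_times)
#     x_std = exponential_moving_standardize(
#         x, factor_new=0.001, init_block_size=200
#     )
#     assert x_std.shape == x.shape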


def exponential_moving_demean(
    data: NDArray, factor_new: float = 0.001, init_block_size: int | None = None
):
    r"""Perform exponential moving demeaning.

    Compute the exponential moving mean :math:`m_t` at time `t` as
    :math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Demean the data point :math:`x_t` at time `t` as:
    :math:`x'_t=(x_t - m_t)`.

    Parameters
    ----------
    data : np.ndarray (n_channels, n_times)
    factor_new : float
    init_block_size : int
        Demean data up to this index with regular demeaning.

    Returns
    -------
    demeaned : np.ndarray (n_channels, n_times)
        Demeaned data.
    """
    data = data.T
    df = pd.DataFrame(data)
    meaned = df.ewm(alpha=factor_new).mean()
    demeaned = df - meaned
    demeaned = np.array(demeaned)
    if init_block_size is not None:
        i_time_axis = 0
        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        demeaned[0:init_block_size] = data[0:init_block_size] - init_mean
    return demeaned.T
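
# Usage mirrors ``exponential_moving_standardize`` (illustrative sketch;
# ``x`` is any (n_channels, n_times) array):
#
#     x_dm = exponential_moving_demean(x, factor_new=0.001, init_block_size=200)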


def filterbank(
    raw: BaseRaw,
    frequency_bands: list[tuple[float, float]],
    drop_original_signals: bool = True,
    order_by_frequency_band: bool = False,
    **mne_filter_kwargs,
):
    """Apply multiple bandpass filters to the signals in raw.

    The raw will be modified in-place and the number of channels in raw will
    be updated to len(frequency_bands) * len(raw.ch_names) (minus
    len(raw.ch_names) if drop_original_signals).

    Parameters
    ----------
    raw : mne.io.Raw
        The raw signals to be filtered.
    frequency_bands : list(tuple)
        The frequency bands to be filtered for (e.g. [(4, 8), (8, 13)]).
    drop_original_signals : bool
        Whether to drop the original unfiltered signals.
    order_by_frequency_band : bool
        If True, will return channels ordered by frequency bands, so if there
        are channels Cz, O1 and filterbank ranges [(4, 8), (8, 13)], the
        returned channels will be [Cz_4-8, O1_4-8, Cz_8-13, O1_8-13]. If
        False, the order will be [Cz_4-8, Cz_8-13, O1_4-8, O1_8-13].
    mne_filter_kwargs : dict
        Keyword arguments for filtering supported by mne.io.Raw.filter().
        Please refer to MNE for a detailed explanation.
    """
    if not frequency_bands:
        raise ValueError(
            f"Expected at least one frequency band, got {frequency_bands}"
        )
    if not all([len(ch_name) < 8 for ch_name in raw.ch_names]):
        warn(
            "Try to use shorter channel names, since frequency band "
            "annotation requires an estimated 4-8 chars depending on the "
            "frequency ranges. Will truncate to 15 chars (mne max)."
        )
    original_ch_names = raw.ch_names
    all_filtered = []
    for l_freq, h_freq in frequency_bands:
        filtered = raw.copy()
        filtered.filter(l_freq=l_freq, h_freq=h_freq, **mne_filter_kwargs)
        # mne automatically changes the highpass/lowpass info values
        # when applying filters, and channels can't be added if they have
        # different such parameters. Not needed when making picks, as the
        # high pass is not modified by filter if a pick is specified
        ch_names = filtered.info.ch_names
        ch_types = filtered.info.get_channel_types()
        sampling_freq = filtered.info["sfreq"]
        info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sampling_freq)
        filtered.info = info
        # add frequency band annotation to channel names and truncate to a
        # max of 15 characters, since mne does not allow for more
        filtered.rename_channels(
            {
                old_name: (old_name + f"_{l_freq}-{h_freq}")[-15:]
                for old_name in filtered.ch_names
            }
        )
        all_filtered.append(filtered)
    raw.add_channels(all_filtered)
    if not order_by_frequency_band:
        # order channels by name and not by frequency band:
        # index the list with a step size of the number of original channels
        # for each of the original channels
        chs_by_freq_band = []
        for i in range(len(original_ch_names)):
            chs_by_freq_band.extend(raw.ch_names[i :: len(original_ch_names)])
        raw.reorder_channels(chs_by_freq_band)
    if drop_original_signals:
        raw.drop_channels(original_ch_names)
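
# A minimal usage sketch for ``filterbank`` (illustrative; assumes an MNE Raw
# object named ``raw`` with preloaded data):
#
#     filterbank(raw, frequency_bands=[(4, 8), (8, 13), (13, 30)])
#     # Each original channel is replaced by one copy per band, e.g.
#     # 'Cz' -> 'Cz_4-8', 'Cz_8-13', 'Cz_13-30'.
#
#     # Or inside a preprocessing pipeline (the callable operates on the Raw
#     # itself, hence apply_on_array=False):
#     preproc = Preprocessor(
#         filterbank, apply_on_array=False, frequency_bands=[(4, 8), (8, 13)]
#     )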