# braindecode/preprocessing/preprocess.py
"""Preprocessors that work on Raw or Epochs objects."""
2
3
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
4
#          Lukas Gemein <l.gemein@gmail.com>
5
#          Simon Brandt <simonbrandt@protonmail.com>
6
#          David Sabbagh <dav.sabbagh@gmail.com>
7
#          Bruno Aristimunha <b.aristimunha@gmail.com>
8
#
9
# License: BSD (3-clause)
10
11
from __future__ import annotations
12
from warnings import warn
13
from functools import partial
14
from collections.abc import Iterable
15
import sys
16
import platform
17
18
if sys.version_info < (3, 9):
19
    from typing import Callable
20
else:
21
    from collections.abc import Callable
22
23
import numpy as np
24
from numpy.typing import NDArray
25
import pandas as pd
26
from mne import create_info, BaseEpochs
27
from mne.io import BaseRaw
28
from joblib import Parallel, delayed
29
30
from braindecode.datasets.base import (
31
    BaseConcatDataset,
32
    BaseDataset,
33
    WindowsDataset,
34
    EEGWindowsDataset,
35
)
36
from braindecode.datautil.serialization import (
37
    load_concat_dataset,
38
    _check_save_dir_empty,
39
)
40
41
42
class Preprocessor(object):
43
    """Preprocessor for an MNE Raw or Epochs object.
44
45
    Applies the provided preprocessing function to the data of a Raw or Epochs
46
    object.
47
    If the function is provided as a string, the method with that name will be
48
    used (e.g., 'pick_channels', 'filter', etc.).
49
    If it is provided as a callable and `apply_on_array` is True, the
50
    `apply_function` method of Raw and Epochs object will be used to apply the
51
    function on the internal arrays of Raw and Epochs.
52
    If `apply_on_array` is False, the callable must directly modify the Raw or
53
    Epochs object (e.g., by calling its method(s) or modifying its attributes).
54
55
    Parameters
56
    ----------
57
    fn: str or callable
58
        If str, the Raw/Epochs object must have a method with that name.
59
        If callable, directly apply the callable to the object.
60
    apply_on_array : bool
61
        Ignored if `fn` is not a callable. If True, the `apply_function` of Raw
62
        and Epochs object will be used to run `fn` on the underlying arrays
63
        directly. If False, `fn` must directly modify the Raw or Epochs object.
64
    kwargs:
65
        Keyword arguments to be forwarded to the MNE function.
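
    Examples
    --------
    A minimal sketch; the ``scale`` function and its ``factor`` argument
    below are illustrative, not part of the library:

    >>> from braindecode.preprocessing import Preprocessor
    >>> # Call an MNE method by name, forwarding keyword arguments:
    >>> bandpass = Preprocessor('filter', l_freq=4., h_freq=30.)
    >>> # Run a custom callable on the underlying arrays
    >>> # (apply_on_array=True is the default):
    >>> def scale(x, factor):
    ...     return x * factor
    >>> scaler = Preprocessor(scale, factor=1e6)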
    """

    def __init__(self, fn: Callable | str, *, apply_on_array: bool = True, **kwargs):
        if hasattr(fn, "__name__") and fn.__name__ == "<lambda>":
            warn("Preprocessing choices with lambda functions cannot be saved.")
        if callable(fn) and apply_on_array:
            channel_wise = kwargs.pop("channel_wise", False)
            picks = kwargs.pop("picks", None)
            n_jobs = kwargs.pop("n_jobs", 1)
            kwargs = dict(
                fun=partial(fn, **kwargs),
                channel_wise=channel_wise,
                picks=picks,
                n_jobs=n_jobs,
            )
            fn = "apply_function"
        self.fn = fn
        self.kwargs = kwargs

    def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
        try:
            self._try_apply(raw_or_epochs)
        except RuntimeError:
            # Maybe the function needs the data to be loaded and the data was
            # not loaded yet. Not all MNE functions need data to be loaded;
            # most importantly, 'crop' can be applied lazily without
            # preloading data, which can make the overall preprocessing
            # pipeline substantially faster.
            raw_or_epochs.load_data()
            self._try_apply(raw_or_epochs)

    def _try_apply(self, raw_or_epochs):
        if callable(self.fn):
            self.fn(raw_or_epochs, **self.kwargs)
        else:
            if not hasattr(raw_or_epochs, self.fn):
                raise AttributeError(f"MNE object does not have a {self.fn} method.")
            getattr(raw_or_epochs, self.fn)(**self.kwargs)


def preprocess(
    concat_ds: BaseConcatDataset,
    preprocessors: list[Preprocessor],
    save_dir: str | None = None,
    overwrite: bool = False,
    n_jobs: int | None = None,
    offset: int = 0,
):
    """Apply preprocessors to a concat dataset.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        A concat of BaseDataset or WindowsDataset datasets to be preprocessed.
    preprocessors : list(Preprocessor)
        List of Preprocessor objects to apply to the dataset.
    save_dir : str | None
        If a string, the preprocessed data will be saved under the specified
        directory and the datasets in ``concat_ds`` will be reloaded with
        `preload=False`.
    overwrite : bool
        When `save_dir` is provided, controls whether to delete the old
        subdirectories that will be written to under `save_dir`. If False and
        the corresponding subdirectories already exist, a ``FileExistsError``
        will be raised.
    n_jobs : int | None
        Number of jobs for parallel execution. See `joblib.Parallel` for
        a more detailed explanation.
    offset : int
        If provided, the integer is added to the id of the dataset in the
        concat. This is useful in the setting of very large datasets, where
        one dataset has to be processed and saved at a time to account for
        its original position.

    Returns
    -------
    BaseConcatDataset:
        Preprocessed dataset.
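
    Examples
    --------
    A hedged sketch, assuming ``concat_ds`` is an existing
    ``BaseConcatDataset`` built elsewhere (hence the skipped doctest):

    >>> from braindecode.preprocessing import preprocess, Preprocessor
    >>> preprocessors = [
    ...     Preprocessor('resample', sfreq=100),
    ...     Preprocessor('filter', l_freq=4., h_freq=30.),
    ... ]
    >>> concat_ds = preprocess(concat_ds, preprocessors)  # doctest: +SKIP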
    """
    # In case of serialization, make sure the directory is available before
    # preprocessing
    if save_dir is not None and not overwrite:
        _check_save_dir_empty(save_dir)

    if not isinstance(preprocessors, Iterable):
        raise ValueError("preprocessors must be a list of Preprocessor objects.")
    for elem in preprocessors:
        assert hasattr(elem, "apply"), "Preprocessor object needs an `apply` method."

    parallel_processing = (n_jobs is not None) and (n_jobs != 1)

    job_prefer = "threads" if platform.system() == "Windows" else None
    list_of_ds = Parallel(n_jobs=n_jobs, prefer=job_prefer)(
        delayed(_preprocess)(
            ds,
            i + offset,
            preprocessors,
            save_dir,
            overwrite,
            copy_data=(parallel_processing and (save_dir is None)),
        )
        for i, ds in enumerate(concat_ds.datasets)
    )

    if save_dir is not None:  # Reload datasets and replace in concat_ds
        ids_to_load = [i + offset for i in range(len(concat_ds.datasets))]
        concat_ds_reloaded = load_concat_dataset(
            save_dir,
            preload=False,
            target_name=None,
            ids_to_load=ids_to_load,
        )
        _replace_inplace(concat_ds, concat_ds_reloaded)
    else:
        if parallel_processing:  # joblib made copies
            _replace_inplace(concat_ds, BaseConcatDataset(list_of_ds))
        else:
            # joblib did not make copies, so the preprocessing happened
            # in-place. Recompute cumulative sizes, as transforms might have
            # changed them.
            concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)

    return concat_ds


def _replace_inplace(concat_ds, new_concat_ds):
    """Replace subdatasets and preproc_kwargs of a BaseConcatDataset inplace.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        Dataset to modify inplace.
    new_concat_ds : BaseConcatDataset
        Dataset to use to modify ``concat_ds``.
    """
    if len(concat_ds.datasets) != len(new_concat_ds.datasets):
        raise ValueError("Both inputs must have the same length.")
    for i in range(len(new_concat_ds.datasets)):
        concat_ds.datasets[i] = new_concat_ds.datasets[i]

    concat_kind = "raw" if hasattr(concat_ds.datasets[0], "raw") else "window"
    preproc_kwargs_attr = concat_kind + "_preproc_kwargs"
    if hasattr(new_concat_ds, preproc_kwargs_attr):
        setattr(
            concat_ds, preproc_kwargs_attr, getattr(new_concat_ds, preproc_kwargs_attr)
        )


def _preprocess(
    ds, ds_index, preprocessors, save_dir=None, overwrite=False, copy_data=False
):
    """Apply preprocessor(s) to Raw or Epochs object.

    Parameters
    ----------
    ds : BaseDataset | WindowsDataset
        Dataset object to preprocess.
    ds_index : int
        Index of the BaseDataset in its BaseConcatDataset. Ignored if save_dir
        is None.
    preprocessors : list(Preprocessor)
        List of preprocessors to apply to the dataset.
    save_dir : str | None
        If provided, save the preprocessed BaseDataset in the
        specified directory.
    overwrite : bool
        If True, overwrite existing file with the same name.
    copy_data : bool
        First copy the data in case it is preloaded. Necessary for parallel
        processing to work.
    """

    def _preprocess_raw_or_epochs(raw_or_epochs, preprocessors):
        # Copying the data is necessary in some scenarios for parallel
        # processing to work when the data is in memory (otherwise an error
        # is raised about `_data` not being writeable)
        if raw_or_epochs.preload and copy_data:
            raw_or_epochs._data = raw_or_epochs._data.copy()
        for preproc in preprocessors:
            preproc.apply(raw_or_epochs)

    if hasattr(ds, "raw"):
        if isinstance(ds, EEGWindowsDataset):
            warn(
                f"Applying preprocessors {preprocessors} to the mne.io.Raw of an EEGWindowsDataset."
            )
        _preprocess_raw_or_epochs(ds.raw, preprocessors)
    elif hasattr(ds, "windows"):
        _preprocess_raw_or_epochs(ds.windows, preprocessors)
    else:
        raise ValueError(
            "Can only preprocess concatenation of BaseDataset or "
            "WindowsDataset, with either a `raw` or `windows` attribute."
        )

    # Store preprocessing keyword arguments in the dataset
    _set_preproc_kwargs(ds, preprocessors)

    if save_dir is not None:
        concat_ds = BaseConcatDataset([ds])
        concat_ds.save(save_dir, overwrite=overwrite, offset=ds_index)
    else:
        return ds


def _get_preproc_kwargs(preprocessors):
    preproc_kwargs = []
    for p in preprocessors:
        # In case of an MNE method, fn is a str and kwargs is a dict
        func_name = p.fn
        func_kwargs = p.kwargs
        # In case of a custom callable with apply_on_array=False
        if callable(p.fn):
            func_name = p.fn.__name__
        # In case of a custom callable with apply_on_array=True, fn was
        # replaced by 'apply_function' and the callable wrapped in a partial
        else:
            if "fun" in p.fn:
                func_name = p.kwargs["fun"].func.__name__
                func_kwargs = p.kwargs["fun"].keywords
        preproc_kwargs.append((func_name, func_kwargs))
    return preproc_kwargs


def _set_preproc_kwargs(ds, preprocessors):
    """Record preprocessing keyword arguments in BaseDataset or WindowsDataset.

    Parameters
    ----------
    ds : BaseDataset | WindowsDataset
        Dataset in which to record preprocessing keyword arguments.
    preprocessors : list
        List of preprocessors.
    """
    preproc_kwargs = _get_preproc_kwargs(preprocessors)
    if isinstance(ds, EEGWindowsDataset):
        kind = "raw"
    elif isinstance(ds, WindowsDataset):
        kind = "window"
    elif isinstance(ds, BaseDataset):
        kind = "raw"
    else:
        raise TypeError(f"ds must be a BaseDataset or a WindowsDataset, got {type(ds)}")
    setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)


def exponential_moving_standardize(
    data: NDArray,
    factor_new: float = 0.001,
    init_block_size: int | None = None,
    eps: float = 1e-4,
):
    r"""Perform exponential moving standardization.

    Compute the exponential moving mean :math:`m_t` at time `t` as
    :math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Then, compute the exponential moving variance :math:`v_t` at time `t` as
    :math:`v_t=\mathrm{factornew} \cdot (m_t - x_t)^2 + (1 - \mathrm{factornew}) \cdot v_{t-1}`.

    Finally, standardize the data point :math:`x_t` at time `t` as:
    :math:`x'_t=(x_t - m_t) / max(\sqrt{v_t}, eps)`.

    Parameters
    ----------
    data : np.ndarray (n_channels, n_times)
    factor_new : float
    init_block_size : int
        Standardize data before this index with regular standardization.
    eps : float
        Stabilizer for division by zero variance.

    Returns
    -------
    standardized : np.ndarray (n_channels, n_times)
        Standardized data.
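
    Examples
    --------
    Illustrative shape-preserving call on random data:

    >>> import numpy as np
    >>> data = np.random.randn(2, 1000)  # (n_channels, n_times)
    >>> standardized = exponential_moving_standardize(
    ...     data, factor_new=0.001, init_block_size=200)
    >>> standardized.shape
    (2, 1000)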
    """
    data = data.T
    df = pd.DataFrame(data)
    meaned = df.ewm(alpha=factor_new).mean()
    demeaned = df - meaned
    squared = demeaned * demeaned
    square_ewmed = squared.ewm(alpha=factor_new).mean()
    standardized = demeaned / np.maximum(eps, np.sqrt(np.array(square_ewmed)))
    standardized = np.array(standardized)
    if init_block_size is not None:
        i_time_axis = 0
        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        init_std = np.std(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        init_block_standardized = (data[0:init_block_size] - init_mean) / np.maximum(
            eps, init_std
        )
        standardized[0:init_block_size] = init_block_standardized
    return standardized.T


def exponential_moving_demean(
    data: NDArray, factor_new: float = 0.001, init_block_size: int | None = None
):
    r"""Perform exponential moving demeaning.

    Compute the exponential moving mean :math:`m_t` at time `t` as
    :math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Demean the data point :math:`x_t` at time `t` as:
    :math:`x'_t=(x_t - m_t)`.

    Parameters
    ----------
    data : np.ndarray (n_channels, n_times)
    factor_new : float
    init_block_size : int
        Demean data before this index with regular demeaning.

    Returns
    -------
    demeaned : np.ndarray (n_channels, n_times)
        Demeaned data.
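
    Examples
    --------
    Illustrative shape-preserving call on random data:

    >>> import numpy as np
    >>> data = np.random.randn(2, 1000)  # (n_channels, n_times)
    >>> demeaned = exponential_moving_demean(
    ...     data, factor_new=0.001, init_block_size=200)
    >>> demeaned.shape
    (2, 1000)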
    """
    data = data.T
    df = pd.DataFrame(data)
    meaned = df.ewm(alpha=factor_new).mean()
    demeaned = df - meaned
    demeaned = np.array(demeaned)
    if init_block_size is not None:
        i_time_axis = 0
        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
        demeaned[0:init_block_size] = data[0:init_block_size] - init_mean
    return demeaned.T


def filterbank(
    raw: BaseRaw,
    frequency_bands: list[tuple[float, float]],
    drop_original_signals: bool = True,
    order_by_frequency_band: bool = False,
    **mne_filter_kwargs,
):
    """Apply multiple bandpass filters to the signals in raw.

    The raw object is modified in-place, and its number of channels is
    updated to ``len(frequency_bands) * len(raw.ch_names)`` (minus
    ``len(raw.ch_names)`` if ``drop_original_signals`` is True).

    Parameters
    ----------
    raw : mne.io.Raw
        The raw signals to be filtered.
    frequency_bands : list(tuple)
        The frequency bands to filter for (e.g. [(4, 8), (8, 13)]).
    drop_original_signals : bool
        Whether to drop the original, unfiltered signals.
    order_by_frequency_band : bool
        If True, the returned channels are ordered by frequency band, so for
        channels Cz, O1 and filterbank ranges [(4, 8), (8, 13)], the returned
        channels will be [Cz_4-8, O1_4-8, Cz_8-13, O1_8-13]. If False, the
        order will be [Cz_4-8, Cz_8-13, O1_4-8, O1_8-13].
    mne_filter_kwargs : dict
        Keyword arguments for filtering supported by ``mne.io.Raw.filter()``.
        Please refer to mne for a detailed explanation.
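
    Examples
    --------
    A hedged sketch, assuming ``raw`` is a preloaded ``mne.io.Raw`` (hence
    the skipped doctest); channels come back renamed per band, e.g.
    ``Cz_4-8``:

    >>> from braindecode.preprocessing import filterbank
    >>> filterbank(raw, frequency_bands=[(4, 8), (8, 13)],
    ...            drop_original_signals=True)  # doctest: +SKIP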
    """
    if not frequency_bands:
        raise ValueError(f"Expected at least one frequency band, got {frequency_bands}")
    if not all([len(ch_name) < 8 for ch_name in raw.ch_names]):
        warn(
            "Try to use shorter channel names, since frequency band "
            "annotation requires an estimated 4-8 chars depending on the "
            "frequency ranges. Will truncate to 15 chars (mne max)."
        )
    original_ch_names = raw.ch_names
    all_filtered = []
    for l_freq, h_freq in frequency_bands:
        filtered = raw.copy()
        filtered.filter(l_freq=l_freq, h_freq=h_freq, **mne_filter_kwargs)
        # mne automatically changes the highpass/lowpass info values when
        # applying filters, and channels cannot be added if these parameters
        # differ. This is not needed when making picks, as the highpass is
        # not modified by filter if a pick is specified.

        ch_names = filtered.info.ch_names
        ch_types = filtered.info.get_channel_types()
        sampling_freq = filtered.info["sfreq"]

        info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sampling_freq)

        filtered.info = info

        # Add the frequency band annotation to the channel names, truncating
        # to a max of 15 characters, since mne does not allow for more
        filtered.rename_channels(
            {
                old_name: (old_name + f"_{l_freq}-{h_freq}")[-15:]
                for old_name in filtered.ch_names
            }
        )
        all_filtered.append(filtered)
    raw.add_channels(all_filtered)
    if not order_by_frequency_band:
        # Order channels by name and not by frequency band: index the list
        # with a step size of the number of original channels, once for each
        # of the original channels
        chs_by_freq_band = []
        for i in range(len(original_ch_names)):
            chs_by_freq_band.extend(raw.ch_names[i :: len(original_ch_names)])
        raw.reorder_channels(chs_by_freq_band)
    if drop_original_signals:
        raw.drop_channels(original_ch_names)