--- a
+++ b/braindecode/preprocessing/preprocess.py
@@ -0,0 +1,469 @@
+"""Preprocessors that work on Raw or Epochs objects."""
+
+# Authors: Hubert Banville <hubert.jbanville@gmail.com>
+#          Lukas Gemein <l.gemein@gmail.com>
+#          Simon Brandt <simonbrandt@protonmail.com>
+#          David Sabbagh <dav.sabbagh@gmail.com>
+#          Bruno Aristimunha <b.aristimunha@gmail.com>
+#
+# License: BSD (3-clause)
+
+from __future__ import annotations
+from warnings import warn
+from functools import partial
+from collections.abc import Iterable
+import sys
+import platform
+
+if sys.version_info < (3, 9):
+    from typing import Callable
+else:
+    from collections.abc import Callable
+
+import numpy as np
+from numpy.typing import NDArray
+import pandas as pd
+from mne import create_info, BaseEpochs
+from mne.io import BaseRaw
+from joblib import Parallel, delayed
+
+from braindecode.datasets.base import (
+    BaseConcatDataset,
+    BaseDataset,
+    WindowsDataset,
+    EEGWindowsDataset,
+)
+from braindecode.datautil.serialization import (
+    load_concat_dataset,
+    _check_save_dir_empty,
+)
+
+
+class Preprocessor(object):
+    """Preprocessor for an MNE Raw or Epochs object.
+
+    Applies the provided preprocessing function to the data of a Raw or Epochs
+    object.
+    If the function is provided as a string, the method with that name will be
+    used (e.g., 'pick_channels', 'filter', etc.).
+    If it is provided as a callable and `apply_on_array` is True, the
+    `apply_function` method of Raw and Epochs objects will be used to apply
+    the function on the internal arrays of Raw and Epochs.
+    If `apply_on_array` is False, the callable must directly modify the Raw or
+    Epochs object (e.g., by calling its method(s) or modifying its attributes).
+
+    Parameters
+    ----------
+    fn: str or callable
+        If str, the Raw/Epochs object must have a method with that name.
+        If callable, directly apply the callable to the object.
+    apply_on_array : bool
+        Ignored if `fn` is not a callable. If True, the `apply_function` method
+        of Raw and Epochs objects will be used to run `fn` on the underlying
+        arrays directly. If False, `fn` must directly modify the Raw or Epochs
+        object.
+    kwargs:
+        Keyword arguments to be forwarded to the MNE function.
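+
+    Examples
+    --------
+    A minimal sketch; the method name and the ``scale`` helper below are
+    illustrative, not prescribed by this class:
+
+    >>> def scale(data, factor):
+    ...     return data * factor
+    >>> preprocessors = [
+    ...     Preprocessor('pick_types', eeg=True, meg=False),  # MNE method name
+    ...     Preprocessor(scale, factor=1e6),  # callable run via apply_function
+    ... ]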
+    """
+
+    def __init__(self, fn: Callable | str, *, apply_on_array: bool = True, **kwargs):
+        if hasattr(fn, "__name__") and fn.__name__ == "<lambda>":
+            warn("Preprocessing choices with lambda functions cannot be saved.")
+        if callable(fn) and apply_on_array:
+            channel_wise = kwargs.pop("channel_wise", False)
+            picks = kwargs.pop("picks", None)
+            n_jobs = kwargs.pop("n_jobs", 1)
+            kwargs = dict(
+                fun=partial(fn, **kwargs),
+                channel_wise=channel_wise,
+                picks=picks,
+                n_jobs=n_jobs,
+            )
+            fn = "apply_function"
+        self.fn = fn
+        self.kwargs = kwargs
+
+    def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
+        try:
+            self._try_apply(raw_or_epochs)
+        except RuntimeError:
+            # Maybe the function needs the data to be loaded and the data was
+            # not loaded yet. Not all MNE functions need data to be loaded;
+            # most importantly, 'crop' can be applied lazily without
+            # preloading data, which can make the overall preprocessing
+            # pipeline substantially faster.
+            raw_or_epochs.load_data()
+            self._try_apply(raw_or_epochs)
+
+    def _try_apply(self, raw_or_epochs):
+        if callable(self.fn):
+            self.fn(raw_or_epochs, **self.kwargs)
+        else:
+            if not hasattr(raw_or_epochs, self.fn):
+                raise AttributeError(f"MNE object does not have a {self.fn} method.")
+            getattr(raw_or_epochs, self.fn)(**self.kwargs)
+
+
+def preprocess(
+    concat_ds: BaseConcatDataset,
+    preprocessors: list[Preprocessor],
+    save_dir: str | None = None,
+    overwrite: bool = False,
+    n_jobs: int | None = None,
+    offset: int = 0,
+):
+    """Apply preprocessors to a concat dataset.
+
+    Parameters
+    ----------
+    concat_ds: BaseConcatDataset
+        A concat of BaseDataset or WindowsDataset datasets to be preprocessed.
+    preprocessors: list(Preprocessor)
+        List of Preprocessor objects to apply to the dataset.
+    save_dir : str | None
+        If a string, the preprocessed data will be saved under the specified
+        directory and the datasets in ``concat_ds`` will be reloaded with
+        `preload=False`.
+    overwrite : bool
+        When `save_dir` is provided, controls whether to delete the old
+        subdirectories that will be written to under `save_dir`. If False and
+        the corresponding subdirectories already exist, a ``FileExistsError``
+        will be raised.
+    n_jobs : int | None
+        Number of jobs for parallel execution. See `joblib.Parallel` for
+        a more detailed explanation.
+    offset : int
+        If provided, this integer is added to the id of each dataset in the
+        concat. This is useful for very large datasets, where one dataset
+        has to be processed and saved at a time while keeping track of its
+        original position in the concat.
+
+    Returns
+    -------
+    BaseConcatDataset:
+        Preprocessed dataset.
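+
+    Examples
+    --------
+    A hedged sketch, assuming ``concat_ds`` and ``preprocessors`` already
+    exist (e.g. as built in the :class:`Preprocessor` example):
+
+    >>> preprocess(concat_ds, preprocessors, n_jobs=2)  # doctest: +SKIP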
+    """
+    # In case of serialization, make sure directory is available before
+    # preprocessing
+    if save_dir is not None and not overwrite:
+        _check_save_dir_empty(save_dir)
+
+    if not isinstance(preprocessors, Iterable):
+        raise ValueError("preprocessors must be a list of Preprocessor objects.")
+    for elem in preprocessors:
+        assert hasattr(elem, "apply"), "Preprocessor object needs an `apply` method."
+
+    parallel_processing = (n_jobs is not None) and (n_jobs != 1)
+
+    job_prefer = "threads" if platform.system() == "Windows" else None
+    list_of_ds = Parallel(n_jobs=n_jobs, prefer=job_prefer)(
+        delayed(_preprocess)(
+            ds,
+            i + offset,
+            preprocessors,
+            save_dir,
+            overwrite,
+            copy_data=(parallel_processing and (save_dir is None)),
+        )
+        for i, ds in enumerate(concat_ds.datasets)
+    )
+
+    if save_dir is not None:  # Reload datasets and replace in concat_ds
+        ids_to_load = [i + offset for i in range(len(concat_ds.datasets))]
+        concat_ds_reloaded = load_concat_dataset(
+            save_dir,
+            preload=False,
+            target_name=None,
+            ids_to_load=ids_to_load,
+        )
+        _replace_inplace(concat_ds, concat_ds_reloaded)
+    else:
+        if parallel_processing:  # joblib made copies
+            _replace_inplace(concat_ds, BaseConcatDataset(list_of_ds))
+        else:
+            # joblib did not make copies, so the preprocessing happened
+            # in-place. Recompute cumulative sizes as transforms might have
+            # changed them.
+            concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)
+
+    return concat_ds
+
+
+def _replace_inplace(concat_ds, new_concat_ds):
+    """Replace subdatasets and preproc_kwargs of a BaseConcatDataset inplace.
+
+    Parameters
+    ----------
+    concat_ds : BaseConcatDataset
+        Dataset to modify inplace.
+    new_concat_ds : BaseConcatDataset
+        Dataset to use to modify ``concat_ds``.
+    """
+    if len(concat_ds.datasets) != len(new_concat_ds.datasets):
+        raise ValueError("Both inputs must have the same length.")
+    for i in range(len(new_concat_ds.datasets)):
+        concat_ds.datasets[i] = new_concat_ds.datasets[i]
+
+    concat_kind = "raw" if hasattr(concat_ds.datasets[0], "raw") else "window"
+    preproc_kwargs_attr = concat_kind + "_preproc_kwargs"
+    if hasattr(new_concat_ds, preproc_kwargs_attr):
+        setattr(
+            concat_ds, preproc_kwargs_attr, getattr(new_concat_ds, preproc_kwargs_attr)
+        )
+
+
+def _preprocess(
+    ds, ds_index, preprocessors, save_dir=None, overwrite=False, copy_data=False
+):
+    """Apply preprocessor(s) to Raw or Epochs object.
+
+    Parameters
+    ----------
+    ds: BaseDataset | WindowsDataset
+        Dataset object to preprocess.
+    ds_index : int
+        Index of the BaseDataset in its BaseConcatDataset. Ignored if save_dir
+        is None.
+    preprocessors: list(Preprocessor)
+        List of preprocessors to apply to the dataset.
+    save_dir : str | None
+        If provided, save the preprocessed BaseDataset in the
+        specified directory.
+    overwrite : bool
+        If True, overwrite existing file with the same name.
+    copy_data : bool
+        If True, copy the data first in case it is preloaded. Necessary for
+        parallel processing to work.
+    """
+
+    def _preprocess_raw_or_epochs(raw_or_epochs, preprocessors):
+        # Copying the data is necessary in some scenarios for parallel
+        # processing to work when the data is preloaded in memory (otherwise
+        # an error about _data not being writeable is raised)
+        if raw_or_epochs.preload and copy_data:
+            raw_or_epochs._data = raw_or_epochs._data.copy()
+        for preproc in preprocessors:
+            preproc.apply(raw_or_epochs)
+
+    if hasattr(ds, "raw"):
+        if isinstance(ds, EEGWindowsDataset):
+            warn(
+                f"Applying preprocessors {preprocessors} to the mne.io.Raw of an EEGWindowsDataset."
+            )
+        _preprocess_raw_or_epochs(ds.raw, preprocessors)
+    elif hasattr(ds, "windows"):
+        _preprocess_raw_or_epochs(ds.windows, preprocessors)
+    else:
+        raise ValueError(
+            "Can only preprocess concatenation of BaseDataset or "
+            "WindowsDataset, with either a `raw` or `windows` attribute."
+        )
+
+    # Store preprocessing keyword arguments in the dataset
+    _set_preproc_kwargs(ds, preprocessors)
+
+    if save_dir is not None:
+        concat_ds = BaseConcatDataset([ds])
+        concat_ds.save(save_dir, overwrite=overwrite, offset=ds_index)
+    else:
+        return ds
+
+
+def _get_preproc_kwargs(preprocessors):
+    preproc_kwargs = []
+    for p in preprocessors:
+        # in the case of an MNE method, fn is a str and kwargs is a dict
+        func_name = p.fn
+        func_kwargs = p.kwargs
+        # in the case of a callable with apply_on_array=False
+        if callable(p.fn):
+            func_name = p.fn.__name__
+        # with apply_on_array=True, fn was replaced by "apply_function" and
+        # the callable is stored as a partial under kwargs["fun"]
+        else:
+            if "fun" in p.fn:
+                func_name = p.kwargs["fun"].func.__name__
+                func_kwargs = p.kwargs["fun"].keywords
+        preproc_kwargs.append((func_name, func_kwargs))
+    return preproc_kwargs
+
+
+def _set_preproc_kwargs(ds, preprocessors):
+    """Record preprocessing keyword arguments in BaseDataset or WindowsDataset.
+
+    Parameters
+    ----------
+    ds : BaseDataset | WindowsDataset
+        Dataset in which to record preprocessing keyword arguments.
+    preprocessors : list
+        List of preprocessors.
+    """
+    preproc_kwargs = _get_preproc_kwargs(preprocessors)
+    # check EEGWindowsDataset first, as it keeps an mne Raw, while a plain
+    # WindowsDataset keeps mne Epochs ("windows")
+    if isinstance(ds, EEGWindowsDataset):
+        kind = "raw"
+    elif isinstance(ds, WindowsDataset):
+        kind = "window"
+    elif isinstance(ds, BaseDataset):
+        kind = "raw"
+    else:
+        raise TypeError(f"ds must be a BaseDataset or a WindowsDataset, got {type(ds)}")
+    setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)
+
+
+def exponential_moving_standardize(
+    data: NDArray,
+    factor_new: float = 0.001,
+    init_block_size: int | None = None,
+    eps: float = 1e-4,
+):
+    r"""Perform exponential moving standardization.
+
+    Compute the exponential moving mean :math:`m_t` at time `t` as
+    :math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.
+
+    Then, compute the exponential moving variance :math:`v_t` at time `t` as
+    :math:`v_t=\mathrm{factornew} \cdot (m_t - x_t)^2 + (1 - \mathrm{factornew}) \cdot v_{t-1}`.
+
+    Finally, standardize the data point :math:`x_t` at time `t` as:
+    :math:`x'_t=(x_t - m_t) / max(\sqrt{v_t}, eps)`.
+
+
+    Parameters
+    ----------
+    data: np.ndarray (n_channels, n_times)
+        Data to standardize.
+    factor_new: float
+        Smoothing factor, i.e. the weight given to the newest sample (passed
+        as ``alpha`` to the exponentially weighted moving operations).
+    init_block_size: int
+        Standardize data before this index with regular standardization.
+    eps: float
+        Stabilizer for division by zero variance.
+
+    Returns
+    -------
+    standardized: np.ndarray (n_channels, n_times)
+        Standardized data.
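+
+    Examples
+    --------
+    A minimal sketch on random data (shapes are illustrative):
+
+    >>> import numpy as np
+    >>> data = np.random.RandomState(0).randn(2, 1000)  # (n_channels, n_times)
+    >>> exponential_moving_standardize(data, init_block_size=200).shape
+    (2, 1000)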
+    """
+    data = data.T
+    df = pd.DataFrame(data)
+    meaned = df.ewm(alpha=factor_new).mean()
+    demeaned = df - meaned
+    squared = demeaned * demeaned
+    square_ewmed = squared.ewm(alpha=factor_new).mean()
+    standardized = demeaned / np.maximum(eps, np.sqrt(np.array(square_ewmed)))
+    standardized = np.array(standardized)
+    if init_block_size is not None:
+        i_time_axis = 0
+        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
+        init_std = np.std(data[0:init_block_size], axis=i_time_axis, keepdims=True)
+        init_block_standardized = (data[0:init_block_size] - init_mean) / np.maximum(
+            eps, init_std
+        )
+        standardized[0:init_block_size] = init_block_standardized
+    return standardized.T
+
+
+def exponential_moving_demean(
+    data: NDArray, factor_new: float = 0.001, init_block_size: int | None = None
+):
+    r"""Perform exponential moving demeaning.
+
+    Compute the exponential moving mean :math:`m_t` at time `t` as
+    :math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.
+
+    Demean the data point :math:`x_t` at time `t` as:
+    :math:`x'_t=(x_t - m_t)`.
+
+    Parameters
+    ----------
+    data: np.ndarray (n_channels, n_times)
+        Data to demean.
+    factor_new: float
+        Smoothing factor, i.e. the weight given to the newest sample (passed
+        as ``alpha`` to the exponentially weighted moving mean).
+    init_block_size: int
+        Demean data before this index with regular demeaning.
+
+    Returns
+    -------
+    demeaned: np.ndarray (n_channels, n_times)
+        Demeaned data.
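+
+    Examples
+    --------
+    Analogous to :func:`exponential_moving_standardize` (shapes illustrative):
+
+    >>> import numpy as np
+    >>> data = np.random.RandomState(0).randn(2, 1000)  # (n_channels, n_times)
+    >>> exponential_moving_demean(data, init_block_size=200).shape
+    (2, 1000)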
+    """
+    data = data.T
+    df = pd.DataFrame(data)
+    meaned = df.ewm(alpha=factor_new).mean()
+    demeaned = df - meaned
+    demeaned = np.array(demeaned)
+    if init_block_size is not None:
+        i_time_axis = 0
+        init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
+        demeaned[0:init_block_size] = data[0:init_block_size] - init_mean
+    return demeaned.T
+
+
+def filterbank(
+    raw: BaseRaw,
+    frequency_bands: list[tuple[float, float]],
+    drop_original_signals: bool = True,
+    order_by_frequency_band: bool = False,
+    **mne_filter_kwargs,
+):
+    """Apply multiple bandpass filters to the signals in raw.
+
+    The raw will be modified in-place and the number of channels in raw will
+    be updated to ``len(frequency_bands) * len(raw.ch_names)``
+    (minus ``len(raw.ch_names)`` if ``drop_original_signals`` is True).
+
+    Parameters
+    ----------
+    raw: mne.io.Raw
+        The raw signals to be filtered.
+    frequency_bands: list(tuple)
+        The frequency bands to be filtered for (e.g. [(4, 8), (8, 13)]).
+    drop_original_signals: bool
+        Whether to drop the original, unfiltered signals.
+    order_by_frequency_band: bool
+        If True, the channels are ordered by frequency band: with channels
+        Cz, O1 and filterbank ranges [(4, 8), (8, 13)], the returned order is
+        [Cz_4-8, O1_4-8, Cz_8-13, O1_8-13]. If False, the order is
+        [Cz_4-8, Cz_8-13, O1_4-8, O1_8-13].
+    mne_filter_kwargs: dict
+        Keyword arguments for filtering supported by mne.io.Raw.filter().
+        Please refer to mne for a detailed explanation.
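+
+    Examples
+    --------
+    A hedged sketch, assuming ``raw`` is an already loaded mne.io.Raw with
+    channels ['Cz', 'O1']:
+
+    >>> filterbank(raw, frequency_bands=[(4, 8), (8, 13)])  # doctest: +SKIP
+    >>> raw.ch_names  # doctest: +SKIP
+    ['Cz_4-8', 'Cz_8-13', 'O1_4-8', 'O1_8-13']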
+    """
+    if not frequency_bands:
+        raise ValueError(f"Expected at least one frequency band, got {frequency_bands}")
+    if not all([len(ch_name) < 8 for ch_name in raw.ch_names]):
+        warn(
+            "Try to use shorter channel names, since frequency band "
+            "annotation requires an estimated 4-8 chars depending on the "
+            "frequency ranges. Will truncate to 15 chars (mne max)."
+        )
+    original_ch_names = raw.ch_names
+    all_filtered = []
+    for l_freq, h_freq in frequency_bands:
+        filtered = raw.copy()
+        filtered.filter(l_freq=l_freq, h_freq=h_freq, **mne_filter_kwargs)
+        # mne automatically changes the highpass/lowpass info values when
+        # applying filters, and channels cannot be added if these values
+        # differ, so the info is recreated below. This is not needed when
+        # picks are given, as the highpass is then not modified by filter.
+
+        ch_names = filtered.info.ch_names
+        ch_types = filtered.info.get_channel_types()
+        sampling_freq = filtered.info["sfreq"]
+
+        info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sampling_freq)
+
+        filtered.info = info
+
+        # add frequency band annotation to channel names
+        # truncate to a max of 15 characters, since mne does not allow for more
+        filtered.rename_channels(
+            {
+                old_name: (old_name + f"_{l_freq}-{h_freq}")[-15:]
+                for old_name in filtered.ch_names
+            }
+        )
+        all_filtered.append(filtered)
+    raw.add_channels(all_filtered)
+    if not order_by_frequency_band:
+        # order channels by name and not by frequency band: for each original
+        # channel, slice the channel list with a step size equal to the
+        # number of original channels
+        chs_by_freq_band = []
+        for i in range(len(original_ch_names)):
+            chs_by_freq_band.extend(raw.ch_names[i :: len(original_ch_names)])
+        raw.reorder_channels(chs_by_freq_band)
+    if drop_original_signals:
+        raw.drop_channels(original_ch_names)