"""Preprocessors that work on Raw or Epochs objects."""
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
# Lukas Gemein <l.gemein@gmail.com>
# Simon Brandt <simonbrandt@protonmail.com>
# David Sabbagh <dav.sabbagh@gmail.com>
# Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)
from __future__ import annotations
from warnings import warn
from functools import partial
from collections.abc import Iterable
import sys
import platform
if sys.version_info < (3, 9):
from typing import Callable
else:
from collections.abc import Callable
import numpy as np
from numpy.typing import NDArray
import pandas as pd
from mne import create_info, BaseEpochs
from mne.io import BaseRaw
from joblib import Parallel, delayed
from braindecode.datasets.base import (
BaseConcatDataset,
BaseDataset,
WindowsDataset,
EEGWindowsDataset,
)
from braindecode.datautil.serialization import (
load_concat_dataset,
_check_save_dir_empty,
)
class Preprocessor(object):
"""Preprocessor for an MNE Raw or Epochs object.
Applies the provided preprocessing function to the data of a Raw or Epochs
object.
If the function is provided as a string, the method with that name will be
used (e.g., 'pick_channels', 'filter', etc.).
If it is provided as a callable and `apply_on_array` is True, the
`apply_function` method of Raw and Epochs objects will be used to apply the
function to the internal arrays of Raw and Epochs.
If `apply_on_array` is False, the callable must directly modify the Raw or
Epochs object (e.g., by calling its method(s) or modifying its attributes).
Parameters
----------
fn: str or callable
If str, the Raw/Epochs object must have a method with that name.
If callable, directly apply the callable to the object.
apply_on_array : bool
Ignored if `fn` is not a callable. If True, the `apply_function` method of
Raw and Epochs objects will be used to run `fn` on the underlying arrays
directly. If False, `fn` must directly modify the Raw or Epochs object.
kwargs:
Keyword arguments forwarded to the MNE method (if `fn` is a string) or to
`fn` itself (if it is a callable). If `apply_on_array` is True, the special
keys ``channel_wise``, ``picks`` and ``n_jobs`` are forwarded to
`apply_function` instead.
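Examples
--------
A minimal sketch, assuming ``raw`` is an existing :class:`mne.io.Raw`; the
parameter values are illustrative only:

>>> import numpy as np
>>> # call an MNE method by name, forwarding keyword arguments to it
>>> Preprocessor('filter', l_freq=4., h_freq=30.).apply(raw)  # doctest: +SKIP
>>> # run a callable on the underlying array through `apply_function`
>>> Preprocessor(np.nan_to_num, apply_on_array=True).apply(raw)  # doctest: +SKIP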
"""
def __init__(self, fn: Callable | str, *, apply_on_array: bool = True, **kwargs):
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>":
warn("Preprocessing choices with lambda functions cannot be saved.")
if callable(fn) and apply_on_array:
channel_wise = kwargs.pop("channel_wise", False)
picks = kwargs.pop("picks", None)
n_jobs = kwargs.pop("n_jobs", 1)
kwargs = dict(
fun=partial(fn, **kwargs),
channel_wise=channel_wise,
picks=picks,
n_jobs=n_jobs,
)
fn = "apply_function"
self.fn = fn
self.kwargs = kwargs
def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
try:
self._try_apply(raw_or_epochs)
except RuntimeError:
# Maybe the function needs the data to be loaded and the data was
# not loaded yet. Not all MNE functions need data to be loaded,
# most importantly the 'crop' function can be lazily applied
# without preloading data which can make the overall preprocessing
# pipeline substantially faster.
raw_or_epochs.load_data()
self._try_apply(raw_or_epochs)
def _try_apply(self, raw_or_epochs):
if callable(self.fn):
self.fn(raw_or_epochs, **self.kwargs)
else:
if not hasattr(raw_or_epochs, self.fn):
raise AttributeError(f"MNE object does not have a {self.fn} method.")
getattr(raw_or_epochs, self.fn)(**self.kwargs)
def preprocess(
concat_ds: BaseConcatDataset,
preprocessors: list[Preprocessor],
save_dir: str | None = None,
overwrite: bool = False,
n_jobs: int | None = None,
offset: int = 0,
):
"""Apply preprocessors to a concat dataset.
Parameters
----------
concat_ds: BaseConcatDataset
A concat of BaseDataset or WindowsDataset datasets to be preprocessed.
preprocessors: list(Preprocessor)
List of Preprocessor objects to apply to the dataset.
save_dir : str | None
If a string, the preprocessed data will be saved under the specified
directory and the datasets in ``concat_ds`` will be reloaded with
`preload=False`.
overwrite : bool
When `save_dir` is provided, controls whether to delete the old
subdirectories that will be written to under `save_dir`. If False and
the corresponding subdirectories already exist, a ``FileExistsError``
will be raised.
n_jobs : int | None
Number of jobs for parallel execution. See `joblib.Parallel` for
a more detailed explanation.
offset : int
If provided, this integer is added to the id of each dataset in the
concat. This is useful when very large datasets are preprocessed and saved
one dataset at a time, so that each saved dataset keeps its original
position.
Returns
-------
BaseConcatDataset:
Preprocessed dataset.
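Examples
--------
A minimal sketch, assuming ``concat_ds`` is an existing
:class:`BaseConcatDataset` of raw recordings; the resampling and filtering
values are illustrative only:

>>> preprocessors = [
...     Preprocessor('resample', sfreq=100.),
...     Preprocessor('filter', l_freq=4., h_freq=30.),
... ]
>>> preprocess(concat_ds, preprocessors, n_jobs=1)  # doctest: +SKIP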
"""
# In case of serialization, make sure the target directory is empty before
# preprocessing starts
if save_dir is not None and not overwrite:
_check_save_dir_empty(save_dir)
if not isinstance(preprocessors, Iterable):
raise ValueError("preprocessors must be a list of Preprocessor objects.")
for elem in preprocessors:
assert hasattr(elem, "apply"), "Preprocessor object needs an `apply` method."
parallel_processing = (n_jobs is not None) and (n_jobs != 1)
job_prefer = "threads" if platform.system() == "Windows" else None
list_of_ds = Parallel(n_jobs=n_jobs, prefer=job_prefer)(
delayed(_preprocess)(
ds,
i + offset,
preprocessors,
save_dir,
overwrite,
copy_data=(parallel_processing and (save_dir is None)),
)
for i, ds in enumerate(concat_ds.datasets)
)
if save_dir is not None: # Reload datasets and replace in concat_ds
ids_to_load = [i + offset for i in range(len(concat_ds.datasets))]
concat_ds_reloaded = load_concat_dataset(
save_dir,
preload=False,
target_name=None,
ids_to_load=ids_to_load,
)
_replace_inplace(concat_ds, concat_ds_reloaded)
else:
if parallel_processing: # joblib made copies
_replace_inplace(concat_ds, BaseConcatDataset(list_of_ds))
else:  # joblib did not make copies, the preprocessing happened in-place
# Recompute cumulative sizes as transforms might have changed them
concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)
return concat_ds
def _replace_inplace(concat_ds, new_concat_ds):
"""Replace subdatasets and preproc_kwargs of a BaseConcatDataset inplace.
Parameters
----------
concat_ds : BaseConcatDataset
Dataset to modify inplace.
new_concat_ds : BaseConcatDataset
Dataset to use to modify ``concat_ds``.
"""
if len(concat_ds.datasets) != len(new_concat_ds.datasets):
raise ValueError("Both inputs must have the same length.")
for i in range(len(new_concat_ds.datasets)):
concat_ds.datasets[i] = new_concat_ds.datasets[i]
concat_kind = "raw" if hasattr(concat_ds.datasets[0], "raw") else "window"
preproc_kwargs_attr = concat_kind + "_preproc_kwargs"
if hasattr(new_concat_ds, preproc_kwargs_attr):
setattr(
concat_ds, preproc_kwargs_attr, getattr(new_concat_ds, preproc_kwargs_attr)
)
def _preprocess(
ds, ds_index, preprocessors, save_dir=None, overwrite=False, copy_data=False
):
"""Apply preprocessor(s) to Raw or Epochs object.
Parameters
----------
ds: BaseDataset | WindowsDataset
Dataset object to preprocess.
ds_index : int
Index of the BaseDataset in its BaseConcatDataset. Ignored if save_dir
is None.
preprocessors: list(Preprocessor)
List of preprocessors to apply to the dataset.
save_dir : str | None
If provided, save the preprocessed BaseDataset in the
specified directory.
overwrite : bool
If True, overwrite existing file with the same name.
copy_data : bool
First copy the data in case it is preloaded. Necessary for parallel processing to work.
"""
def _preprocess_raw_or_epochs(raw_or_epochs, preprocessors):
# Copying the data is necessary in some scenarios for parallel processing to
# work when the data is preloaded in memory (otherwise an error is raised
# about `_data` not being writeable)
if raw_or_epochs.preload and copy_data:
raw_or_epochs._data = raw_or_epochs._data.copy()
for preproc in preprocessors:
preproc.apply(raw_or_epochs)
if hasattr(ds, "raw"):
if isinstance(ds, EEGWindowsDataset):
warn(
f"Applying preprocessors {preprocessors} to the mne.io.Raw of an EEGWindowsDataset."
)
_preprocess_raw_or_epochs(ds.raw, preprocessors)
elif hasattr(ds, "windows"):
_preprocess_raw_or_epochs(ds.windows, preprocessors)
else:
raise ValueError(
"Can only preprocess concatenation of BaseDataset or "
"WindowsDataset, with either a `raw` or `windows` attribute."
)
# Store preprocessing keyword arguments in the dataset
_set_preproc_kwargs(ds, preprocessors)
if save_dir is not None:
concat_ds = BaseConcatDataset([ds])
concat_ds.save(save_dir, overwrite=overwrite, offset=ds_index)
else:
return ds
def _get_preproc_kwargs(preprocessors):
preproc_kwargs = []
for p in preprocessors:
# in case of an MNE function, fn is a str and kwargs is a dict
func_name = p.fn
func_kwargs = p.kwargs
# in case of another function
# if apply_on_array=False
if callable(p.fn):
func_name = p.fn.__name__
# if apply_on_array=True
else:
if "fun" in p.fn:
func_name = p.kwargs["fun"].func.__name__
func_kwargs = p.kwargs["fun"].keywords
preproc_kwargs.append((func_name, func_kwargs))
return preproc_kwargs
def _set_preproc_kwargs(ds, preprocessors):
"""Record preprocessing keyword arguments in BaseDataset or WindowsDataset.
Parameters
----------
ds : BaseDataset | WindowsDataset
Dataset in which to record preprocessing keyword arguments.
preprocessors : list
List of preprocessors.
"""
preproc_kwargs = _get_preproc_kwargs(preprocessors)
if isinstance(ds, WindowsDataset):
kind = "window"
if isinstance(ds, EEGWindowsDataset):
kind = "raw"
elif isinstance(ds, BaseDataset):
kind = "raw"
else:
raise TypeError(f"ds must be a BaseDataset or a WindowsDataset, got {type(ds)}")
setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)
def exponential_moving_standardize(
data: NDArray,
factor_new: float = 0.001,
init_block_size: int | None = None,
eps: float = 1e-4,
):
r"""Perform exponential moving standardization.
Compute the exponential moving mean :math:`m_t` at time `t` as
:math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.
Then, compute exponential moving variance :math:`v_t` at time `t` as
:math:`v_t=\mathrm{factornew} \cdot (m_t - x_t)^2 + (1 - \mathrm{factornew}) \cdot v_{t-1}`.
Finally, standardize the data point :math:`x_t` at time `t` as:
:math:`x'_t=(x_t - m_t) / max(\sqrt{v_t}, eps)`.
Parameters
----------
data: np.ndarray (n_channels, n_times)
factor_new: float
Smoothing factor ``alpha`` of the exponential moving statistics.
init_block_size: int
Standardize data before this index with regular standardization.
eps: float
Stabilizer for division by zero variance.
Returns
-------
standardized: np.ndarray (n_channels, n_times)
Standardized data.
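Examples
--------
A small sketch on random data of shape ``(n_channels, n_times)``:

>>> import numpy as np
>>> rng = np.random.RandomState(0)
>>> x = rng.randn(2, 1000)
>>> exponential_moving_standardize(x, factor_new=0.001, init_block_size=200).shape
(2, 1000)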
"""
data = data.T
df = pd.DataFrame(data)
meaned = df.ewm(alpha=factor_new).mean()
demeaned = df - meaned
squared = demeaned * demeaned
square_ewmed = squared.ewm(alpha=factor_new).mean()
standardized = demeaned / np.maximum(eps, np.sqrt(np.array(square_ewmed)))
standardized = np.array(standardized)
if init_block_size is not None:
i_time_axis = 0
init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
init_std = np.std(data[0:init_block_size], axis=i_time_axis, keepdims=True)
init_block_standardized = (data[0:init_block_size] - init_mean) / np.maximum(
eps, init_std
)
standardized[0:init_block_size] = init_block_standardized
return standardized.T
def exponential_moving_demean(
data: NDArray, factor_new: float = 0.001, init_block_size: int | None = None
):
r"""Perform exponential moving demeanining.
Compute the exponental moving mean :math:`m_t` at time `t` as
:math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.
Demean the data point :math:`x_t` at time `t` as:
:math:`x'_t=(x_t - m_t)`.
Parameters
----------
data: np.ndarray (n_channels, n_times)
factor_new: float
Smoothing factor ``alpha`` of the exponential moving mean.
init_block_size: int
Demean data before this index with regular demeaning.
Returns
-------
demeaned: np.ndarray (n_channels, n_times)
Demeaned data.
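Examples
--------
A small sketch on random data of shape ``(n_channels, n_times)``:

>>> import numpy as np
>>> rng = np.random.RandomState(0)
>>> x = rng.randn(2, 1000)
>>> exponential_moving_demean(x, factor_new=0.001, init_block_size=200).shape
(2, 1000)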
"""
data = data.T
df = pd.DataFrame(data)
meaned = df.ewm(alpha=factor_new).mean()
demeaned = df - meaned
demeaned = np.array(demeaned)
if init_block_size is not None:
i_time_axis = 0
init_mean = np.mean(data[0:init_block_size], axis=i_time_axis, keepdims=True)
demeaned[0:init_block_size] = data[0:init_block_size] - init_mean
return demeaned.T
def filterbank(
raw: BaseRaw,
frequency_bands: list[tuple[float, float]],
drop_original_signals: bool = True,
order_by_frequency_band: bool = False,
**mne_filter_kwargs,
):
"""Applies multiple bandpass filters to the signals in raw. The raw will be
modified in-place and number of channels in raw will be updated to
len(frequency_bands) * len(raw.ch_names) (-len(raw.ch_names) if
drop_original_signals).
Parameters
----------
raw: mne.io.Raw
The raw signals to be filtered.
frequency_bands: list(tuple)
The frequency bands to be filtered for (e.g. [(4, 8), (8, 13)]).
drop_original_signals: bool
Whether to drop the original, unfiltered signals.
order_by_frequency_band: bool
If True will return channels ordered by frequency bands, so if there
are channels Cz, O1 and filterbank ranges [(4,8), (8,13)], returned
channels will be [Cz_4-8, O1_4-8, Cz_8-13, O1_8-13]. If False, order
will be [Cz_4-8, Cz_8-13, O1_4-8, O1_8-13].
mne_filter_kwargs: dict
Keyword arguments for filtering supported by mne.io.Raw.filter().
Please refer to mne for a detailed explanation.
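Examples
--------
A minimal sketch, assuming ``raw`` is an existing, preloaded
:class:`mne.io.Raw` with channels ``['Cz', 'O1']`` (the exact names depend
on the band limits and on the 15-character truncation):

>>> filterbank(raw, frequency_bands=[(4, 8), (8, 13)])  # doctest: +SKIP
>>> raw.ch_names  # doctest: +SKIP
['Cz_4-8', 'Cz_8-13', 'O1_4-8', 'O1_8-13']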
"""
if not frequency_bands:
raise ValueError(f"Expected at least one frequency band, got {frequency_bands}")
if not all([len(ch_name) < 8 for ch_name in raw.ch_names]):
warn(
"Try to use shorter channel names, since frequency band "
"annotation requires an estimated 4-8 chars depending on the "
"frequency ranges. Will truncate to 15 chars (mne max)."
)
original_ch_names = raw.ch_names
all_filtered = []
for l_freq, h_freq in frequency_bands:
filtered = raw.copy()
filtered.filter(l_freq=l_freq, h_freq=h_freq, **mne_filter_kwargs)
# mne automatically updates the highpass/lowpass values in `info` when
# filtering, and channels cannot be added if these parameters differ.
# Recreating the info works around this. (Not needed when picks are
# specified, as `filter` then leaves the highpass/lowpass untouched.)
ch_names = filtered.info.ch_names
ch_types = filtered.info.get_channel_types()
sampling_freq = filtered.info["sfreq"]
info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sampling_freq)
filtered.info = info
# add frequency band annotation to channel names
# truncate to a max of 15 characters, since mne does not allow for more
filtered.rename_channels(
{
old_name: (old_name + f"_{l_freq}-{h_freq}")[-15:]
for old_name in filtered.ch_names
}
)
all_filtered.append(filtered)
raw.add_channels(all_filtered)
if not order_by_frequency_band:
# order channels by original channel name and not by frequency band:
# step through the channel list with a stride of len(original_ch_names),
# once for each original channel
chs_by_freq_band = []
for i in range(len(original_ch_names)):
chs_by_freq_band.extend(raw.ch_names[i :: len(original_ch_names)])
raw.reorder_channels(chs_by_freq_band)
if drop_original_signals:
raw.drop_channels(original_ch_names)