"""Bad channel detection using Local Outlier Factor (LOF)."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from .._fiff.pick import _picks_to_idx
from ..io.base import BaseRaw
from ..utils import _soft_import, _validate_type, logger, verbose
@verbose
def find_bad_channels_lof(
raw,
n_neighbors=20,
*,
picks=None,
metric="euclidean",
threshold=1.5,
return_scores=False,
verbose=None,
):
"""Find bad channels using Local Outlier Factor (LOF) algorithm.
Parameters
----------
raw : instance of Raw
Raw data to process.
n_neighbors : int
Number of neighbors defining the local neighborhood (default is 20).
Smaller values will lead to higher LOF scores.
%(picks_good_data)s
metric : str
Metric to use for distance computation. Default is “euclidean”,
see :func:`sklearn.metrics.pairwise.distance_metrics` for details.
threshold : float
Threshold to define outliers. Theoretical threshold ranges anywhere
between 1.0 and any positive integer. Default: 1.5
It is recommended to consider this as an hyperparameter to optimize.
return_scores : bool
If ``True``, return a dictionary with LOF scores for each
evaluated channel. Default is ``False``.
%(verbose)s
Returns
-------
noisy_chs : list
List of bad M/EEG channels that were automatically detected.
scores : ndarray, shape (n_picks,)
Only returned when ``return_scores`` is ``True``. It contains the
LOF outlier score for each channel in ``picks``.
See Also
--------
maxwell_filter
annotate_amplitude
Notes
-----
See :footcite:`KumaravelEtAl2022` and :footcite:`BreunigEtAl2000` for background on
choosing ``threshold``.
.. versionadded:: 1.7
References
----------
.. footbibliography::
""" # noqa: E501
_soft_import("sklearn", "using LOF detection", strict=True)
from sklearn.neighbors import LocalOutlierFactor
_validate_type(raw, BaseRaw, "raw")
# Get the channel types
channel_types = raw.get_channel_types()
picks = _picks_to_idx(raw.info, picks=picks, none="data", exclude="bads")
picked_ch_types = set(channel_types[p] for p in picks)
# Check if there are different channel types
if len(picked_ch_types) != 1:
raise ValueError(
f"Need exactly one channel type in picks, got {sorted(picked_ch_types)}"
)
ch_names = [raw.ch_names[pick] for pick in picks]
data = raw.get_data(picks=picks)
clf = LocalOutlierFactor(n_neighbors=n_neighbors, metric=metric)
clf.fit_predict(data)
scores_lof = clf.negative_outlier_factor_
bad_channel_indices = [
i for i, v in enumerate(np.abs(scores_lof)) if v >= threshold
]
bads = [ch_names[idx] for idx in bad_channel_indices]
logger.info(f"LOF: Detected bad channel(s): {bads}")
if return_scores:
return bads, scores_lof
else:
return bads