"""Tools for data interpolation."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from itertools import chain
import numpy as np
from scipy.sparse.csgraph import connected_components
from .._fiff.meas_info import create_info
from ..epochs import BaseEpochs, EpochsArray
from ..evoked import Evoked, EvokedArray
from ..io import BaseRaw, RawArray
from ..transforms import _cart_to_sph, _sph_to_cart
from ..utils import _ensure_int, _validate_type
def equalize_bads(insts, interp_thresh=1.0, copy=True):
    """Interpolate or mark bads consistently for a list of instances.

    Once called on a list of instances, the instances can be concatenated
    as they will have the same list of bad channels.

    Parameters
    ----------
    insts : list
        The list of instances (Evoked, Epochs or Raw) to consider
        for interpolation. Each instance should have marked channels.
    interp_thresh : float
        A float between 0 and 1 (default 1.0) that specifies the fraction of
        time a channel should be good to be eventually interpolated for
        certain instances. For example if 0.5, a channel which is good at
        least half of the time will be interpolated in the instances where it
        is marked as bad. If 1 then channels will never be interpolated and if
        0 all bad channels will be systematically interpolated.
    copy : bool
        If True then the returned instances will be copies.

    Returns
    -------
    insts_bads : list
        The list of instances, with the same channel(s) marked as bad in all of
        them, possibly with some formerly bad channels interpolated. The
        common bad channels are listed in order of first appearance across
        ``insts``. An empty input yields an empty list.

    Raises
    ------
    ValueError
        If ``interp_thresh`` is not within [0, 1].
    """
    if not 0 <= interp_thresh <= 1:
        raise ValueError(f"interp_thresh must be between 0 and 1, got {interp_thresh}")
    if not insts:
        # Nothing to equalize; avoids indexing into an empty sequence.
        return []

    # Union of all bad channels, de-duplicated while keeping first-seen order
    # so that the resulting bads list is deterministic (a set is not).
    all_bads = list(
        dict.fromkeys(chain.from_iterable(inst.info["bads"] for inst in insts))
    )

    # Number of samples contributed by each instance; Epochs contribute every
    # epoch. Decided per instance so mixed lists are weighted correctly.
    durations = [
        len(inst) * len(inst.times)
        if isinstance(inst, BaseEpochs)
        else len(inst.times)
        for inst in insts
    ]
    total_duration = np.sum(durations)

    bads_keep = []
    for ch_name in all_bads:
        good_duration = sum(
            duration
            for duration, inst in zip(durations, insts)
            if ch_name not in inst.info["bads"]
        )
        # A channel that is good for less than the threshold fraction of the
        # total time stays bad everywhere instead of being interpolated.
        if good_duration / total_duration < interp_thresh:
            bads_keep.append(ch_name)

    if copy:
        insts = [inst.copy() for inst in insts]

    for inst in insts:
        # Interpolate only bads that are not kept bad in every instance.
        if set(inst.info["bads"]) - set(bads_keep):
            inst.interpolate_bads(exclude=bads_keep)
        # Give each instance its own list so they never share mutable state.
        inst.info["bads"] = list(bads_keep)
    return insts
def interpolate_bridged_electrodes(inst, bridged_idx, bad_limit=4):
    """Interpolate bridged electrode pairs.

    Because bridged electrodes contain brain signal, it's just that the
    signal is spatially smeared between the two electrodes, we can
    make a virtual channel midway between the bridged pairs and use
    that to aid in interpolation rather than completely discarding the
    data from the two channels.

    Parameters
    ----------
    inst : instance of Epochs, Evoked, or Raw
        The data object with channels that are to be interpolated.
    bridged_idx : list of tuple
        The indices of channels marked as bridged with each bridged
        pair stored as a tuple.
    bad_limit : int
        The maximum number of electrodes that can be bridged together
        (included) and interpolated. Above this number, an error will be
        raised.

        .. versionadded:: 1.2

    Returns
    -------
    inst : instance of Epochs, Evoked, or Raw
        The modified data object.

    Raises
    ------
    ValueError
        If ``bad_limit`` is not strictly positive.
    RuntimeError
        If ``inst`` has no channel positions, if they are not in ``head``
        coordinates, or if a group of bridged electrodes is larger than
        ``bad_limit``. In that case ``inst`` is left unmodified.

    See Also
    --------
    mne.preprocessing.compute_bridged_electrodes
    """
    _validate_type(inst, (BaseRaw, BaseEpochs, Evoked))
    bad_limit = _ensure_int(bad_limit, "bad_limit")
    if bad_limit <= 0:
        raise ValueError(
            "Argument 'bad_limit' should be a strictly positive "
            f"integer. Provided {bad_limit} is invalid."
        )
    montage = inst.get_montage()
    if montage is None:
        raise RuntimeError("No channel positions found in ``inst``")
    pos = montage.get_positions()
    if pos["coord_frame"] != "head":
        raise RuntimeError(
            f"Montage channel positions must be in ``head`` got {pos['coord_frame']}"
        )

    # Group the bridged channels and validate the group sizes *before* touching
    # ``inst`` so that an error leaves the caller's bads untouched.
    groups_idx = _group_bridged_channels(bridged_idx)
    groups_names = [
        [inst.info.ch_names[ch_idx] for ch_idx in group_idx]
        for group_idx in groups_idx
    ]
    for group_names in groups_names:
        if len(group_names) > bad_limit:
            raise RuntimeError(
                f"The channels {', '.join(group_names)} are bridged together "
                "and form a large area of bridged electrodes. Interpolation "
                "might be inaccurate."
            )
    if not groups_names:
        # Nothing is bridged, hence there is nothing to interpolate.
        return inst

    # Store the original bads; they are restored no matter what happens below.
    bads_orig = list(inst.info["bads"])
    inst.info["bads"] = list()
    added_names = []
    try:
        virtual_names = [f"virtual {k + 1}" for k in range(len(groups_names))]
        virtual_chs = [
            _make_virtual_channel(inst, pos["ch_pos"], group_names, name)
            for name, group_names in zip(virtual_names, groups_names)
        ]
        inst.add_channels(virtual_chs, force_update_info=True)
        added_names = virtual_names
        # Mark every bridged channel as bad and let the virtual channels help
        # interpolate them. Groups are disjoint connected components.
        inst.info["bads"] = list(chain.from_iterable(groups_names))
        inst.interpolate_bads()
    finally:
        # Remove helper channels (if they were added) and put back the bads.
        if added_names:
            inst.drop_channels(added_names)
        inst.info["bads"] = bads_orig
    return inst


def _group_bridged_channels(bridged_idx):
    """Return lists of channel indices connected through bridged pairs."""
    nodes = sorted(set(chain(*bridged_idx)))
    if not nodes:
        return []
    adjacency = np.zeros((len(nodes), len(nodes)))
    # Undirected edge of weight 1 for every bridged pair.
    for bridge in bridged_idx:
        row = np.searchsorted(nodes, bridge[0])
        col = np.searchsorted(nodes, bridge[1])
        adjacency[row, col] = 1
        adjacency[col, row] = 1
    n_groups, labels = connected_components(adjacency, directed=False)
    return [
        [nodes[j] for j in np.where(labels == group)[0]] for group in range(n_groups)
    ]


def _make_virtual_channel(inst, ch_pos, group_names, name):
    """Build a one-channel instance averaging ``group_names`` at their centroid."""
    # Centroid computed on the sphere so the channel stays on the scalp.
    pos_virtual = _find_centroid_sphere(ch_pos, group_names)
    virtual_info = create_info([name], inst.info["sfreq"], "eeg")
    virtual_info["chs"][0]["loc"][:3] = pos_virtual
    data = inst.get_data(picks=group_names)
    if isinstance(inst, BaseRaw):
        data = np.average(data, axis=0).reshape(1, -1)
        return RawArray(data, virtual_info, first_samp=inst.first_samp)
    if isinstance(inst, BaseEpochs):
        # Average over the channel axis, keeping one channel per epoch.
        data = np.average(data, axis=1).reshape(len(data), 1, -1)
        return EpochsArray(data, virtual_info, tmin=inst.tmin)
    # Evoked
    data = np.average(data, axis=0).reshape(1, -1)
    return EvokedArray(
        data,
        virtual_info,
        tmin=inst.tmin,
        nave=inst.nave,
        kind=inst.kind,
    )
def _find_centroid_sphere(ch_pos, group_names):
    """Compute the centroid position between N electrodes.

    A plain Cartesian average of points on the scalp lands closer to the
    head center than the electrodes themselves (it cuts through the scalp).
    Instead, the Cartesian average is used only for its direction: it is
    converted to spherical coordinates and its radius is replaced by the
    average radius of the N electrodes.

    Parameters
    ----------
    ch_pos : OrderedDict
        The position of all channels in cartesian coordinates.
    group_names : list | tuple
        The name of the N electrodes used to determine the centroid.

    Returns
    -------
    pos_centroid : array of shape (3,)
        The position of the centroid in cartesian coordinates.
    """
    xyz = np.array([ch_pos[name] for name in group_names])
    # Column 0 of the spherical representation holds the radius.
    mean_radius = np.average(_cart_to_sph(xyz), axis=0)[0]
    centroid_sph = _cart_to_sph(np.average(xyz, axis=0))
    centroid_sph[0, 0] = mean_radius
    return _sph_to_cart(centroid_sph)[0, :]