# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from numpy.polynomial.legendre import legval
from scipy.interpolate import RectBivariateSpline
from scipy.linalg import pinv
from scipy.spatial.distance import pdist, squareform
from .._fiff.meas_info import _simplify_info
from .._fiff.pick import pick_channels, pick_info, pick_types
from ..surface import _normalize_vectors
from ..utils import _validate_type, logger, verbose, warn
def _calc_h(cosang, stiffness=4, n_legendre_terms=50):
    """Evaluate the spherical spline h function between points on a sphere.

    The function is the truncated Legendre series::

        h(x) = sum_{n=1}^{N} (2n + 1) / (n**(m-1) * (n + 1)**(m-1) * 4 * pi) * P_n(x)

    Parameters
    ----------
    cosang : array-like | float
        Cosine of the angles between pairs of points on a spherical surface,
        i.e. the dot product of the corresponding unit vectors.
    stiffness : float
        Stiffness of the spline, also referred to as ``m``.
    n_legendre_terms : int
        Number of Legendre terms ``N`` to evaluate.

    Returns
    -------
    h : float | np.ndarray of float
        The series evaluated at ``cosang`` (same shape as ``cosang``).
    """
    exponent = stiffness - 1
    # Coefficient of P_0 is zero: the series starts at n = 1.
    coefs = [0]
    for n in range(1, n_legendre_terms + 1):
        denom = n**exponent * (n + 1) ** exponent * 4 * np.pi
        coefs.append((2 * n + 1) / denom)
    return legval(cosang, coefs)
def _calc_g(cosang, stiffness=4, n_legendre_terms=50):
    """Evaluate the spherical spline g function between points on a sphere.

    The function is the truncated Legendre series::

        g(x) = sum_{n=1}^{N} (2n + 1) / (n**m * (n + 1)**m * 4 * pi) * P_n(x)

    Parameters
    ----------
    cosang : array-like of float, shape(n_channels, n_channels)
        Cosine of the angles between pairs of points on a spherical surface,
        i.e. the dot product of the corresponding unit vectors.
    stiffness : float
        Stiffness of the spline (``m``).
    n_legendre_terms : int
        Number of Legendre terms ``N`` to evaluate.

    Returns
    -------
    G : np.ndarray of float, shape(n_channels, n_channels)
        The G matrix (same shape as ``cosang``).
    """
    # Coefficient of P_0 is zero: the series starts at n = 1.
    coefs = [0]
    for n in range(1, n_legendre_terms + 1):
        denom = n**stiffness * (n + 1) ** stiffness * 4 * np.pi
        coefs.append((2 * n + 1) / denom)
    return legval(cosang, coefs)
def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
    """Compute interpolation matrix based on spherical splines.

    Implementation based on [1]. Both sets of positions are normalized onto
    the unit sphere, so they should be expressed relative to the sphere's
    center.

    Parameters
    ----------
    pos_from : np.ndarray of float, shape(n_good_sensors, 3)
        The positions to interpolate from.
    pos_to : np.ndarray of float, shape(n_bad_sensors, 3)
        The positions to interpolate.
    alpha : float | None
        Regularization parameter added to the diagonal of the spline matrix
        of ``pos_from``. If None, no regularization is added. Defaults to 1e-5.

    Returns
    -------
    interpolation : np.ndarray of float, shape(len(pos_to), len(pos_from))
        The interpolation matrix that maps good signals to the location
        of bad signals (``bad_data = interpolation @ good_data``).

    References
    ----------
    [1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
        Spherical splines for scalp potential and current density mapping.
        Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.
    """
    # Work on float copies: _normalize_vectors is called below for its
    # in-place effect (its return value is unused), which must neither modify
    # the caller's arrays nor fail on integer-typed input.
    pos_from = np.array(pos_from, dtype=np.float64)
    pos_to = np.array(pos_to, dtype=np.float64)
    n_from = pos_from.shape[0]
    n_to = pos_to.shape[0]

    # normalize sensor positions to sphere
    _normalize_vectors(pos_from)
    _normalize_vectors(pos_to)

    # cosine angles between source positions
    cosang_from = pos_from.dot(pos_from.T)
    cosang_to_from = pos_to.dot(pos_from.T)
    G_from = _calc_g(cosang_from)
    G_to_from = _calc_g(cosang_to_from)
    assert G_from.shape == (n_from, n_from)
    assert G_to_from.shape == (n_to, n_from)

    if alpha is not None:
        # regularize by adding alpha to every diagonal element of G_from
        G_from.flat[:: len(G_from) + 1] += alpha

    # G bordered by a column/row of ones (constant term) and a zero corner
    C = np.block(
        [
            [G_from, np.ones((n_from, 1))],
            [np.ones((1, n_from)), np.zeros((1, 1))],
        ]
    )
    C_inv = pinv(C)

    # Drop the last column of C_inv: the data vector has no constant entry
    interpolation = np.hstack([G_to_from, np.ones((n_to, 1))]) @ C_inv[:, :-1]
    assert interpolation.shape == (n_to, n_from)
    return interpolation
def _do_interp_dots(inst, interpolation, goods_idx, bads_idx):
    """Overwrite bad-channel data of ``inst`` with mapped good-channel data.

    The channel rows selected by ``goods_idx`` (second-to-last axis of
    ``inst._data``) are left-multiplied by ``interpolation`` and the result is
    written, in place, to the channel rows selected by ``bads_idx``.
    """
    from ..epochs import BaseEpochs
    from ..evoked import Evoked
    from ..io import BaseRaw

    _validate_type(inst, (BaseRaw, BaseEpochs, Evoked), "inst")
    data = inst._data
    source = data[..., goods_idx, :]
    data[..., bads_idx, :] = interpolation @ source
@verbose
def _interpolate_bads_eeg(inst, origin, exclude=None, ecog=False, verbose=None):
    """Interpolate bad EEG (or ECoG) channels in place using spherical splines.

    Parameters
    ----------
    inst : instance of Raw | Epochs | Evoked
        The data to update; picked channels listed in ``inst.info["bads"]``
        are overwritten.
    origin : array-like of float, shape (3,)
        Center of the sphere onto which the sensor positions are projected.
    exclude : list of str | None
        Passed to ``pick_types``; excluded channels are neither used as
        sources nor interpolated. None (default) excludes nothing.
    ecog : bool
        If True, operate on ECoG channels instead of EEG channels.
    verbose : bool | str | int | None
        Control verbosity of the logging output.
    """
    if exclude is None:
        exclude = list()
    bads_idx = np.zeros(len(inst.ch_names), dtype=bool)
    goods_idx = np.zeros(len(inst.ch_names), dtype=bool)

    picks = pick_types(inst.info, meg=False, eeg=not ecog, ecog=ecog, exclude=exclude)
    inst.info._check_consistency()
    bads = set(inst.info["bads"])
    bads_idx[picks] = [inst.ch_names[ch] in bads for ch in picks]

    if len(picks) == 0 or bads_idx.sum() == 0:
        return

    goods_idx[picks] = True
    goods_idx[bads_idx] = False

    pos = inst._get_channel_positions(picks)

    # Make sure only EEG are used: masks restricted to ``picks`` (aligned
    # with ``pos``)
    bads_idx_pos = bads_idx[picks]
    goods_idx_pos = goods_idx[picks]

    # Test the spherical fit. The mean of distance / mean(distance) is always
    # 1, so measure the mean *absolute* relative deviation of each sensor's
    # distance to the origin from the mean radius instead.
    distance = np.linalg.norm(pos - origin, axis=-1)
    rel_deviation = np.mean(np.abs(distance / np.mean(distance) - 1.0))
    if rel_deviation > 0.1:
        warn(
            "Your spherical fit is poor, interpolation results are "
            "likely to be inaccurate."
        )

    pos_good = pos[goods_idx_pos] - origin
    pos_bad = pos[bads_idx_pos] - origin
    logger.info(f"Computing interpolation matrix from {len(pos_good)} sensor positions")
    interpolation = _make_interpolation_matrix(pos_good, pos_bad)

    logger.info(f"Interpolating {len(pos_bad)} sensors")
    _do_interp_dots(inst, interpolation, goods_idx, bads_idx)
@verbose
def _interpolate_bads_ecog(inst, origin, exclude=None, verbose=None):
    """Interpolate bad ECoG channels in place with spherical splines.

    Delegates to ``_interpolate_bads_eeg`` with ``ecog=True``; the other
    parameters are forwarded unchanged.
    """
    _interpolate_bads_eeg(
        inst,
        origin,
        ecog=True,
        exclude=exclude,
        verbose=verbose,
    )
def _interpolate_bads_meg(
    inst, mode="accurate", origin=(0.0, 0.0, 0.04), verbose=None, ref_meg=False
):
    """Interpolate bad MEG channels in place via field mapping.

    Thin wrapper around ``_interpolate_bads_meeg`` with EEG disabled
    (``eeg=False``). ``mode`` and ``origin`` are forwarded to the field
    mapping and ``ref_meg`` to channel picking.
    """
    return _interpolate_bads_meeg(
        inst,
        mode=mode,
        origin=origin,
        eeg=False,
        ref_meg=ref_meg,
        verbose=verbose,
    )
@verbose
def _interpolate_bads_nan(
    inst,
    ch_type,
    ref_meg=False,
    exclude=(),
    *,
    verbose=None,
):
    """Set the data of bad channels of one type to NaN, in place.

    Parameters
    ----------
    inst : instance of Raw | Epochs | Evoked
        The data to update.
    ch_type : str
        Channel type, passed to ``pick_types`` as ``{ch_type: True}``.
    ref_meg : bool
        Passed to ``pick_types``.
    exclude : list of str | tuple
        Channels that are not picked and therefore left untouched.
    verbose : bool | str | int | None
        Control verbosity of the logging output.
    """
    ch_names = inst.info["ch_names"]
    picks_type = pick_types(
        _simplify_info(inst.info), ref_meg=ref_meg, exclude=exclude, **{ch_type: True}
    )
    type_names = {ch_names[pick] for pick in picks_type}
    bads_type = [name for name in inst.info["bads"] if name in type_names]
    if not bads_type or len(picks_type) == 0:
        return

    # blank out the bad channels of this type
    picks_bad = pick_channels(ch_names, bads_type, exclude=[])
    inst._data[..., picks_bad, :] = np.nan
@verbose
def _interpolate_bads_meeg(
    inst,
    mode="accurate",
    origin=(0.0, 0.0, 0.04),
    meg=True,
    eeg=True,
    ref_meg=False,
    exclude=(),
    *,
    method=None,
    verbose=None,
):
    """Interpolate bad MEG and/or EEG channels in place using field mapping.

    For each enabled channel type, the data of the good channels are mapped
    onto the bad channels with ``_map_meg_or_eeg_channels``.

    Parameters
    ----------
    inst : instance of Raw | Epochs | Evoked
        The data to update in place.
    mode : str
        Passed to ``_map_meg_or_eeg_channels``.
    origin : array-like of float, shape (3,)
        Passed to ``_map_meg_or_eeg_channels``.
    meg, eeg : bool
        Whether to interpolate bad MEG / EEG channels.
    ref_meg : bool
        Passed to ``pick_types``.
    exclude : list of str | tuple
        Channels removed from the set of channels to interpolate.
    method : dict | None
        Not read by this function; only the field-mapping approach above is
        applied here.
    verbose : bool | str | int | None
        Control verbosity of the logging output.
    """
    from ..forward import _map_meg_or_eeg_channels

    ch_names = inst.info["ch_names"]
    info = _simplify_info(inst.info)
    for ch_type, do in dict(meg=meg, eeg=eeg).items():
        if not do:
            continue
        kw = dict(meg=False, eeg=False)
        kw[ch_type] = True
        picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **kw)
        picks_good = pick_types(info, ref_meg=ref_meg, exclude="bads", **kw)
        use_ch_names = {ch_names[p] for p in picks_type}
        bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names]
        if len(bads_type) == 0 or len(picks_type) == 0:
            continue
        # Select the bad channels to be interpolated, in channel-index order,
        # so that the rows of ``mapping`` below line up with the channels
        # written by ``_do_interp_dots`` whatever the order of info["bads"].
        picks_bad = np.sort(pick_channels(ch_names, bads_type, exclude=[]))

        # do MNE based interpolation
        if ch_type == "eeg":
            # map onto all picked EEG channels, then keep the bad rows (in
            # the channel order of ``picks_type``)
            picks_to = picks_type
            bad_sel = np.isin(picks_type, picks_bad)
        else:
            picks_to = picks_bad
            bad_sel = slice(None)
        info_from = pick_info(inst.info, picks_good)
        info_to = pick_info(inst.info, picks_to)
        mapping = _map_meg_or_eeg_channels(info_from, info_to, mode=mode, origin=origin)
        mapping = mapping[bad_sel]
        _do_interp_dots(inst, mapping, picks_good, picks_bad)
@verbose
def _interpolate_bads_nirs(inst, exclude=(), verbose=None):
    """Replace bad fNIRS channels by their nearest good same-frequency channel.

    Parameters
    ----------
    inst : instance of Raw | Epochs | Evoked
        The data to update in place.
    exclude : list of str | tuple of str
        Bad channels that must not be replaced; they stay in ``info["bads"]``.
    verbose : bool | str | int | None
        Control verbosity of the logging output.

    Returns
    -------
    inst : instance of Raw | Epochs | Evoked | None
        ``inst`` when bad fNIRS channels were processed, otherwise None.
    """
    from mne.preprocessing.nirs import _validate_nirs_info

    if len(pick_types(inst.info, fnirs=True, exclude=())) == 0:
        return

    # Returns pick of all nirs and ensures channels are correctly ordered
    picks_nirs = np.asarray(_validate_nirs_info(inst.info), dtype=int)
    nirs_ch_names = [inst.info["ch_names"][p] for p in picks_nirs]
    all_bads = set(inst.info["bads"])
    # Everything below works with positions within ``picks_nirs`` (aligned
    # with ``dist``); they are mapped to channel indices only to access data.
    # Any bad channel is unusable as a source...
    is_bad = np.array([ch in all_bads for ch in nirs_ch_names], dtype=bool)
    # ...but only the non-excluded ones get replaced.
    to_replace = is_bad & np.array(
        [ch not in exclude for ch in nirs_ch_names], dtype=bool
    )
    if not to_replace.any():
        return

    locs3d = np.array([inst.info["chs"][p]["loc"][:3] for p in picks_nirs])
    dist = squareform(pdist(locs3d))
    # Assumes _validate_nirs_info orders channels as consecutive pairs that
    # share a location (one per frequency), so the parity of the position
    # selects the frequency — TODO confirm for all fNIRS channel types.
    parity = np.arange(len(picks_nirs)) % 2

    not_replaced = list()
    for bad_pos in np.flatnonzero(to_replace):
        dists_to_bad = dist[bad_pos].copy()  # copy: do not alter ``dist``
        # Ignore distances to self (and to co-located channels)
        dists_to_bad[dists_to_bad == 0] = np.inf
        # Ignore bad channels and channels of the other frequency
        dists_to_bad[is_bad | (parity != parity[bad_pos])] = np.inf
        if np.isinf(dists_to_bad).all():
            not_replaced.append(nirs_ch_names[bad_pos])
            continue
        # Closest remaining good channel of the same frequency
        closest_pos = np.argmin(dists_to_bad)
        inst._data[..., picks_nirs[bad_pos], :] = inst._data[
            ..., picks_nirs[closest_pos], :
        ]
    if not_replaced:
        warn(
            "No good fNIRS channel of the same frequency found for "
            f"{not_replaced}; these channels are left unchanged and kept bad."
        )

    # TODO: this seems like a bug because it does not respect reset_bads
    inst.info["bads"] = [
        ch for ch in inst.info["bads"] if ch in exclude or ch in not_replaced
    ]
    return inst
def _find_seeg_electrode_shaft(pos, tol_shaft=0.002, tol_spacing=1):
    """Group contact positions into straight electrode shafts.

    Parameters
    ----------
    pos : np.ndarray of float
        Contact positions; assumed shape (n_contacts, 3) (rows are used in
        3D cross products) — TODO confirm for all callers.
    tol_shaft : float
        Maximum point-to-line distance (same units as ``pos``) for a contact
        to be assigned to a shaft line.
    tol_spacing : float
        Maximum allowed relative excess of a gap between consecutive contacts
        over the median gap, ``(gap - median) / median``; a larger gap splits
        the shaft into separate groups.

    Returns
    -------
    shafts : list of np.ndarray of int
        For each shaft, indices into ``pos`` of its contacts, sorted by their
        coordinate along the fitted shaft line.
    shaft_ts : list of np.ndarray of float
        For each shaft, the coordinate of each contact along the fitted line
        (scaled by the median inter-contact vector), aligned with ``shafts``.
    """
    # 1) find nearest neighbor to define the electrode shaft line
    # 2) find all contacts on the same line
    # 3) remove contacts with large distances
    dist = squareform(pdist(pos))
    # exclude self-distances from the nearest-neighbor search
    np.fill_diagonal(dist, np.inf)

    shafts = list()
    shaft_ts = list()
    for i, n1 in enumerate(pos):
        # Only the seed ``i`` is checked, so a contact can still end up in
        # more than one shaft if reached again from a later seed.
        if any([i in shaft for shaft in shafts]):
            continue
        n2 = pos[np.argmin(dist[i])]  # 1
        # https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
        # distance of every contact to the line through n1 and n2
        shaft_dists = np.linalg.norm(
            np.cross((pos - n1), (pos - n2)), axis=1
        ) / np.linalg.norm(n2 - n1)
        shaft = np.where(shaft_dists < tol_shaft)[0]  # 2
        shaft_prev = None
        # refine the line until the set of contacts on it stops changing
        for _ in range(10):  # avoid potential cycles
            if np.array_equal(shaft, shaft_prev):
                break
            shaft_prev = shaft
            # compute median shaft line
            # (``i`` in this comprehension is local to it; the seed ``i`` of
            # the outer loop is not rebound)
            v = np.median(
                [
                    pos[i] - pos[j]
                    for idx, i in enumerate(shaft)
                    for j in shaft[idx + 1 :]
                ],
                axis=0,
            )
            c = np.median(pos[shaft], axis=0)
            # recompute distances to the line through c with direction v
            shaft_dists = np.linalg.norm(
                np.cross((pos - c), (pos - c + v)), axis=1
            ) / np.linalg.norm(v)
            shaft = np.where(shaft_dists < tol_shaft)[0]
        # coordinate of each contact along the refined line, used to sort them
        ts = np.array([np.dot(c - n0, v) / np.linalg.norm(v) ** 2 for n0 in pos[shaft]])
        shaft_order = np.argsort(ts)
        shaft = shaft[shaft_order]
        ts = ts[shaft_order]

        # Split the shaft wherever a gap exceeds the median gap by more than
        # tol_spacing (relative) and keep the group that contains the seed
        # contact ``i``, to avoid interpolating across spans between contacts
        t_diffs = np.diff(ts)
        t_diff_med = np.median(t_diffs)
        spacing_errors = (t_diffs - t_diff_med) / t_diff_med
        groups = list()
        group = [shaft[0]]
        for j in range(len(shaft) - 1):
            if spacing_errors[j] > tol_spacing:
                groups.append(group)
                group = [shaft[j + 1]]
            else:
                group.append(shaft[j + 1])
        groups.append(group)
        # NOTE(review): if the refined line no longer contains the seed ``i``,
        # no group matches and the ``[0]`` below raises IndexError.
        group = [group for group in groups if i in group][0]
        ts = ts[np.isin(shaft, group)]
        shaft = np.array(group, dtype=int)
        shafts.append(shaft)
        shaft_ts.append(ts)
    return shafts, shaft_ts
@verbose
def _interpolate_bads_seeg(
    inst, exclude=None, tol_shaft=0.002, tol_spacing=1, verbose=None
):
    """Interpolate bad sEEG contacts in place along their electrode shaft.

    Contacts are grouped into straight shafts with
    ``_find_seeg_electrode_shaft``. For every shaft holding a bad contact, a
    ``RectBivariateSpline`` (default cubic degree) is fitted over (position
    along the shaft, sample index) on the good contacts of that shaft and
    evaluated at the bad contacts.

    Parameters
    ----------
    inst : instance of Raw | Epochs | Evoked
        The data to update in place.
    exclude : list of str | None
        Passed to ``pick_types``; excluded channels are neither used nor
        interpolated. None (default) excludes nothing.
    tol_shaft : float
        Passed to ``_find_seeg_electrode_shaft``.
    tol_spacing : float
        Passed to ``_find_seeg_electrode_shaft``.
    verbose : bool | str | int | None
        Control verbosity of the logging output.

    Raises
    ------
    RuntimeError
        If a shaft containing a bad contact has fewer than 4 good contacts.
        All shafts are checked before any data are modified.
    """
    if exclude is None:
        exclude = list()
    picks = pick_types(inst.info, meg=False, seeg=True, exclude=exclude)
    inst.info._check_consistency()
    ch_names = np.array(inst.ch_names)
    # Make sure only sEEG are used: boolean mask aligned with ``picks`` (and
    # thus with ``pos`` below)
    bads_idx_pos = np.isin(ch_names[picks], inst.info["bads"])

    if len(picks) == 0 or bads_idx_pos.sum() == 0:
        return

    pos = inst._get_channel_positions(picks)

    shafts, shaft_ts = _find_seeg_electrode_shaft(
        pos, tol_shaft=tol_shaft, tol_spacing=tol_spacing
    )

    # Shaft entries index ``pos`` (i.e. positions within ``picks``); map them
    # through ``picks`` before touching channel names or data. Keep shaft
    # (``ts``) order throughout so spline outputs match their destinations.
    picks_bad = np.flatnonzero(bads_idx_pos)
    plans = list()
    for shaft, ts in zip(shafts, shaft_ts):
        is_bad = np.isin(shaft, picks_bad)
        if not is_bad.any():
            continue
        bads_shaft = picks[shaft[is_bad]]
        goods_shaft = picks[shaft[~is_bad]]
        if goods_shaft.size < 4:  # cubic spline requires 4 channels
            msg = "No shaft" if shaft.size < 4 else "Not enough good channels"
            no_shaft_chs = " and ".join(ch_names[bads_shaft])
            raise RuntimeError(
                f"{msg} found in a line with {no_shaft_chs}; "
                "at least 4 good channels on the same line "
                f"are required for interpolation, {goods_shaft.size} found. "
                f"Dropping {no_shaft_chs} is recommended."
            )
        plans.append((bads_shaft, goods_shaft, ts[is_bad], ts[~is_bad]))

    # interpolate the bad contacts
    for bads_shaft, goods_shaft, ts_bad, ts_good in plans:
        logger.debug(
            f"Interpolating {ch_names[bads_shaft]} using "
            f"data from {ch_names[goods_shaft]}"
        )
        z = inst._data[..., goods_shaft, :]
        is_epochs = z.ndim == 3
        if is_epochs:
            # (n_epochs, n_good, n_times) -> (n_good, n_epochs * n_times)
            z = z.swapaxes(0, 1)
            z = z.reshape(z.shape[0], -1)
        y = np.arange(z.shape[-1])
        # ts_good / ts_bad are ascending (shaft order), as required for x
        out = RectBivariateSpline(x=ts_good, y=y, z=z)(x=ts_bad, y=y)
        if is_epochs:
            out = out.reshape(bads_shaft.size, inst._data.shape[0], -1)
            out = out.swapaxes(0, 1)
        inst._data[..., bads_shaft, :] = out