"""Compute a Recursively Applied and Projected MUltiple Signal Classification (RAP-MUSIC).""" # noqa
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from scipy import linalg
from .._fiff.pick import pick_channels_forward, pick_info
from ..fixes import _safe_svd
from ..forward import convert_forward_solution, is_fixed_orient
from ..inverse_sparse.mxne_inverse import _make_dipoles_sparse
from ..minimum_norm.inverse import _log_exp_var
from ..utils import _check_info_inv, fill_doc, logger, verbose
from ._compute_beamformer import _prepare_beamformer_input
@fill_doc
def _apply_rap_music(
    data, info, times, forward, noise_cov, n_dipoles=2, picks=None, use_trap=False
):
    """RAP-MUSIC or TRAP-MUSIC for evoked data.

    Parameters
    ----------
    data : array, shape (n_channels, n_times)
        Evoked data.
    %(info_not_none)s
    times : array
        Times.
    forward : instance of Forward
        Forward operator.
    noise_cov : instance of Covariance
        The noise covariance.
    n_dipoles : int
        The number of dipoles to estimate. The default value is 2.
    picks : list of int
        Caller ensures this is a list of int.
    use_trap : bool
        Use the TRAP-MUSIC variant if True (default False).

    Returns
    -------
    dipoles : list of instances of Dipole
        The dipole fits.
    explained_data : array, shape (n_channels, n_times)
        Data (in the original, non-whitened sensor space) explained by the
        dipoles using a least square fitting with the selected active dipoles
        and their estimated orientation.
    """
    info = pick_info(info, picks)
    del picks
    # Things are much simpler if we avoid surface orientation. Keep a copy of
    # the original source normals first: after the conversion below, the
    # forward's own "source_nn" is no longer guaranteed to hold them, and we
    # still want to sign-align free-orientation estimates with the surface.
    align = forward["source_nn"].copy()
    if forward["surf_ori"] and not is_fixed_orient(forward):
        forward = convert_forward_solution(forward, surf_ori=False)
    is_free_ori, info, _, _, G, whitener, _, _ = _prepare_beamformer_input(
        info, forward, noise_cov=noise_cov, rank=None
    )
    forward = pick_channels_forward(forward, info["ch_names"], ordered=True)
    del info

    # whiten the data (leadfield already whitened)
    M = np.dot(whitener, data)
    del data

    # eigh returns eigenvalues in ascending order, so the last n_dipoles
    # eigenvectors span the estimated signal subspace
    _, eig_vectors = linalg.eigh(np.dot(M, M.T))
    phi_sig = eig_vectors[:, -n_dipoles:]

    n_orient = 3 if is_free_ori else 1
    G.shape = (G.shape[0], -1, n_orient)
    gain = forward["sol"]["data"].copy()
    gain.shape = G.shape
    n_channels = G.shape[0]
    A = np.empty((n_channels, n_dipoles))  # topographies of found sources
    oris = np.empty((n_dipoles, 3))

    G_proj = G.copy()
    phi_sig_proj = phi_sig.copy()

    idxs = list()
    for k in range(n_dipoles):
        # scan all sources for the best subspace correlation with the
        # (projected) signal subspace
        subcorr_max = -1.0
        source_idx, source_ori = 0, [0, 0, 0]
        for i_source in range(G.shape[1]):
            Gk = G_proj[:, i_source]
            subcorr, ori = _compute_subcorr(Gk, phi_sig_proj)
            if subcorr > subcorr_max:
                subcorr_max = subcorr
                source_idx = i_source
                source_ori = ori
                if n_orient == 3:
                    # use the normals saved before any surf_ori conversion
                    surf_normal = align[3 * i_source + 2]
                    # make sure ori is aligned to the surface orientation
                    source_ori *= np.sign(source_ori @ surf_normal) or 1.0
                if n_orient == 1:
                    source_ori = forward["source_nn"][i_source]

        idxs.append(source_idx)
        if n_orient == 3:
            Ak = np.dot(G[:, source_idx], source_ori)
        else:
            Ak = G[:, source_idx, 0]
        A[:, k] = Ak
        oris[k] = source_ori

        logger.info(f"source {k + 1} found: p = {source_idx}")
        if n_orient == 3:
            logger.info("ori = {} {} {}".format(*tuple(oris[k])))

        # project out the topographies of every source found so far
        projection = _compute_proj(A[:, : k + 1])
        G_proj = np.einsum("ab,bso->aso", projection, G)
        phi_sig_proj = np.dot(projection, phi_sig)
        if use_trap:
            # TRAP-MUSIC: keep only the last (n_dipoles - k) columns of the
            # projected signal basis
            phi_sig_proj = phi_sig_proj[:, -(n_dipoles - k) :]
    del G, G_proj

    # least-squares amplitudes of the found (whitened) topographies
    sol = linalg.lstsq(A, M)[0]
    if n_orient == 3:
        X = sol[:, np.newaxis] * oris[:, :, np.newaxis]
        X.shape = (-1, len(times))
    else:
        X = sol

    gain_active = gain[:, idxs]
    if n_orient == 3:
        gain_dip = (oris * gain_active).sum(-1)
        idxs = np.array(idxs)
        active_set = np.array([[3 * idxs, 3 * idxs + 1, 3 * idxs + 2]]).T.ravel()
    else:
        gain_dip = gain_active[:, :, 0]
        active_set = idxs
    gain_active = whitener @ gain_active.reshape(gain.shape[0], -1)
    assert gain_active.shape == (n_channels, X.shape[0])

    # explained data in the non-whitened space; whitened copy for logging
    explained_data = gain_dip @ sol
    M_estimate = whitener @ explained_data
    _log_exp_var(M, M_estimate)
    tstep = np.median(np.diff(times)) if len(times) > 1 else 1.0
    dipoles = _make_dipoles_sparse(
        X, active_set, forward, times[0], tstep, M, gain_active, active_is_idx=True
    )
    # flip each dipole so its orientation agrees in sign with the fitted one
    for dipole, ori in zip(dipoles, oris):
        signs = np.sign((dipole.ori * ori).sum(-1, keepdims=True))
        dipole._ori *= signs
        dipole._amplitude *= signs[:, 0]
    logger.info("[done]")
    return dipoles, explained_data
def _compute_subcorr(G, phi_sig):
    """Compute the subspace correlation.

    Parameters
    ----------
    G : array, shape (n_channels, n_orient)
        Leadfield columns of a single source (one column per orientation).
    phi_sig : array, shape (n_channels, n_sig)
        Basis of the (possibly projected) signal subspace.

    Returns
    -------
    subcorr : float
        First (largest) singular value of ``Ug^H @ phi_sig``, where ``Ug``
        spans the numerically significant column space of ``G``. 0 when ``G``
        is numerically zero.
    ori : array, shape (n_orient,)
        Unit-norm orientation (weights over the columns of ``G``) achieving
        ``subcorr``; all zeros when ``G`` is numerically zero.
    """
    Ug, Sg, Vg = _safe_svd(G, full_matrices=False)
    # Now we look at the actual rank of the forward fields
    # in G and handle the fact that it might be rank deficient
    # eg. when using MEG and a sphere model for which the
    # radial component will be truly 0.
    rank = np.sum(Sg > (Sg[0] * 1e-6))
    if rank == 0:
        # one orientation weight per column of G (not per channel)
        return 0.0, np.zeros(G.shape[1])
    rank = max(rank, 2)  # rank cannot be 1
    Ug, Sg, Vg = Ug[:, :rank], Sg[:rank], Vg[:rank]
    tmp = np.dot(Ug.T.conjugate(), phi_sig)
    Uc, Sc, _ = _safe_svd(tmp, full_matrices=False)
    # map the best-aligned direction back to orientation weights on G
    X = np.dot(Vg.T / Sg[None, :], Uc[:, 0])  # subcorr
    return Sc[0], X / np.linalg.norm(X)
def _compute_proj(A):
    """Build the projector onto the orthogonal complement of ``span(A)``.

    Parameters
    ----------
    A : array, shape (n_channels, n_cols)
        Matrix whose column space is to be projected out (in
        ``_apply_rap_music`` these are the topographies found so far).

    Returns
    -------
    proj : array, shape (n_channels, n_channels)
        ``I - U @ U^H`` with ``U`` the economy-size left singular vectors of
        ``A``.
    """
    n_rows = A.shape[0]
    left_vecs = _safe_svd(A, full_matrices=False)[0]
    return np.eye(n_rows) - np.dot(left_vecs, left_vecs.conj().T)
def _rap_music(evoked, forward, noise_cov, n_dipoles, return_residual, use_trap):
    """Run RAP-MUSIC (``use_trap=False``) or TRAP-MUSIC on an Evoked.

    Returns the list of dipoles, plus the residual Evoked when
    ``return_residual`` is True.
    """
    info, data, times = evoked.info, evoked.data, evoked.times
    picks = _check_info_inv(info, forward, data_cov=None, noise_cov=noise_cov)
    dipoles, explained_data = _apply_rap_music(
        data[picks], info, times, forward, noise_cov, n_dipoles, picks, use_trap
    )
    if not return_residual:
        return dipoles

    # residual = picked channels minus the part explained by the dipoles
    kept_names = [info["ch_names"][idx] for idx in picks]
    residual = evoked.copy().pick(kept_names)
    residual.data -= explained_data
    # Deactivate the projectors that were active, re-add them and apply them
    # again on the residual. NOTE(review): presumably so the residual respects
    # the same projections as the input data -- confirm.
    reapply = [proj for proj in residual.info["projs"] if proj["active"]]
    for proj in reapply:
        proj["active"] = False
    residual.add_proj(reapply, remove_existing=True)
    residual.apply_proj()
    return dipoles, residual
@verbose
def rap_music(
    evoked,
    forward,
    noise_cov,
    n_dipoles=5,
    return_residual=False,
    *,
    verbose=None,
):
    """RAP-MUSIC source localization method.

    Compute Recursively Applied and Projected MUltiple SIgnal Classification
    (RAP-MUSIC) :footcite:`MosherLeahy1999,MosherLeahy1996` on evoked data.

    .. note:: The goodness of fit (GOF) of all the returned dipoles is the
              same and corresponds to the GOF of the full set of dipoles.

    Parameters
    ----------
    evoked : instance of Evoked
        The evoked data to localize.
    forward : instance of Forward
        The forward operator.
    noise_cov : instance of Covariance
        The noise covariance.
    n_dipoles : int
        How many dipoles to look for. The default value is 5.
    return_residual : bool
        If True, also return the residual as an Evoked instance.
    %(verbose)s

    Returns
    -------
    dipoles : list of instance of Dipole
        The dipole fits.
    residual : instance of Evoked
        The residual, i.e. the data not explained by the dipoles.
        Only returned if return_residual is True.

    See Also
    --------
    mne.fit_dipole
    mne.beamformer.trap_music

    Notes
    -----
    .. versionadded:: 0.9.0

    References
    ----------
    .. footbibliography::
    """
    # plain RAP-MUSIC: the signal subspace is never truncated
    return _rap_music(
        evoked,
        forward,
        noise_cov,
        n_dipoles,
        return_residual,
        use_trap=False,
    )
@verbose
def trap_music(
    evoked,
    forward,
    noise_cov,
    n_dipoles=5,
    return_residual=False,
    *,
    verbose=None,
):
    """TRAP-MUSIC source localization method.

    Compute Truncated Recursively Applied and Projected MUltiple SIgnal
    Classification (TRAP-MUSIC) :footcite:`Makela2018` on evoked data.

    .. note:: The goodness of fit (GOF) of all the returned dipoles is the
              same and corresponds to the GOF of the full set of dipoles.

    Parameters
    ----------
    evoked : instance of Evoked
        The evoked data to localize.
    forward : instance of Forward
        The forward operator.
    noise_cov : instance of Covariance
        The noise covariance.
    n_dipoles : int
        How many dipoles to look for. The default value is 5.
    return_residual : bool
        If True, also return the residual as an Evoked instance.
    %(verbose)s

    Returns
    -------
    dipoles : list of instance of Dipole
        The dipole fits.
    residual : instance of Evoked
        The residual, i.e. the data not explained by the dipoles.
        Only returned if return_residual is True.

    See Also
    --------
    mne.fit_dipole
    mne.beamformer.rap_music

    Notes
    -----
    .. versionadded:: 1.4

    References
    ----------
    .. footbibliography::
    """
    # TRAP variant: the projected signal subspace is truncated each iteration
    return _rap_music(
        evoked,
        forward,
        noise_cov,
        n_dipoles,
        return_residual,
        use_trap=True,
    )