# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ..fixes import _safe_svd
from ..forward import is_fixed_orient
from ..minimum_norm.inverse import _check_reference, _log_exp_var
from ..utils import logger, verbose, warn
from .mxne_inverse import (
_check_ori,
_compute_residual,
_make_dipoles_sparse,
_make_sparse_stc,
_prepare_gain,
_reapply_source_weighting,
)


@verbose
def _gamma_map_opt(
M,
G,
alpha,
maxit=10000,
tol=1e-6,
update_mode=1,
group_size=1,
gammas=None,
verbose=None,
):
"""Hierarchical Bayes (Gamma-MAP).
Parameters
----------
M : array, shape=(n_sensors, n_times)
Observation.
G : array, shape=(n_sensors, n_sources)
Forward operator.
alpha : float
Regularization parameter (noise variance).
maxit : int
Maximum number of iterations.
tol : float
Tolerance parameter for convergence.
group_size : int
Number of consecutive sources which use the same gamma.
update_mode : int
        Update mode, 1: MacKay update (default), 2: Modified MacKay update.
gammas : array, shape=(n_sources,)
Initial values for posterior variances (gammas). If None, a
variance of 1.0 is used.
%(verbose)s

    Returns
-------
X : array, shape=(n_active, n_times)
Estimated source time courses.
active_set : array, shape=(n_active,)
Indices of active sources.
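
    Notes
    -----
    In outline (a sketch of the iteration implemented below): with the model
    data covariance ``CM = alpha * I + G @ diag(gammas) @ G.T`` and
    ``A = G.T @ inv(CM) @ M``, the posterior mean of the sources is
    ``gammas[:, None] * A`` (up to the normalization applied internally), and
    each iteration applies one of two fixed-point updates to every active
    source ``i``::

        # update_mode=1 (MacKay)
        gammas[i] *= mean(abs(A[i]) ** 2) / (G[:, i] @ inv(CM) @ G[:, i])
        # update_mode=2 (modified MacKay)
        gammas[i] *= sqrt(mean(abs(A[i]) ** 2) / (G[:, i] @ inv(CM) @ G[:, i]))

    Gammas that fall to (numerical) zero are pruned from the active set and
    never re-enter it.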
"""
G = G.copy()
M = M.copy()
if gammas is None:
gammas = np.ones(G.shape[1], dtype=np.float64)
eps = np.finfo(float).eps
n_sources = G.shape[1]
n_sensors, n_times = M.shape
# apply normalization so the numerical values are sane
M_normalize_constant = np.linalg.norm(np.dot(M, M.T), ord="fro")
M /= np.sqrt(M_normalize_constant)
alpha /= M_normalize_constant
G_normalize_constant = np.linalg.norm(G, ord=np.inf)
G /= G_normalize_constant
if n_sources % group_size != 0:
raise ValueError(
"Number of sources has to be evenly dividable by the group size"
)
n_active = n_sources
active_set = np.arange(n_sources)
gammas_full_old = gammas.copy()
if update_mode == 2:
denom_fun = np.sqrt
else:
        # identity: use the denominator as-is
def denom_fun(x):
return x
last_size = -1
for itno in range(maxit):
gammas[np.isnan(gammas)] = 0.0
gidx = np.abs(gammas) > eps
active_set = active_set[gidx]
gammas = gammas[gidx]
# update only active gammas (once set to zero it stays at zero)
if n_active > len(active_set):
n_active = active_set.size
G = G[:, gidx]
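        # model data covariance restricted to the active set:
        # CM = alpha * I + G @ diag(gammas) @ G.T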
CM = np.dot(G * gammas[np.newaxis, :], G.T)
CM.flat[:: n_sensors + 1] += alpha
# Invert CM keeping symmetry
U, S, _ = _safe_svd(CM, full_matrices=False)
S = S[np.newaxis, :]
del CM
CMinv = np.dot(U / (S + eps), U.T)
CMinvG = np.dot(CMinv, G)
A = np.dot(CMinvG.T, M) # mult. w. Diag(gamma) in gamma update
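        # the posterior mean of source i is gammas[i] * A[i]; the fixed-point
        # updates below re-estimate each gamma from A and G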
if update_mode == 1:
# MacKay fixed point update (10) in [1]
numer = gammas**2 * np.mean((A * A.conj()).real, axis=1)
denom = gammas * np.sum(G * CMinvG, axis=0)
elif update_mode == 2:
# modified MacKay fixed point update (11) in [1]
numer = gammas * np.sqrt(np.mean((A * A.conj()).real, axis=1))
denom = np.sum(G * CMinvG, axis=0) # sqrt is applied below
else:
raise ValueError("Invalid value for update_mode")
if group_size == 1:
if denom is None:
gammas = numer
else:
gammas = numer / np.maximum(denom_fun(denom), np.finfo("float").eps)
else:
numer_comb = np.sum(numer.reshape(-1, group_size), axis=1)
if denom is None:
gammas_comb = numer_comb
else:
denom_comb = np.sum(denom.reshape(-1, group_size), axis=1)
gammas_comb = numer_comb / denom_fun(denom_comb)
gammas = np.repeat(gammas_comb / group_size, group_size)
# compute convergence criterion
gammas_full = np.zeros(n_sources, dtype=np.float64)
gammas_full[active_set] = gammas
err = np.sum(np.abs(gammas_full - gammas_full_old)) / np.sum(
np.abs(gammas_full_old)
)
gammas_full_old = gammas_full
breaking = err < tol or n_active == 0
if len(gammas) != last_size or breaking:
logger.info(
f"Iteration: {itno}\t active set size: {len(gammas)}\t convergence: "
f"{err:.3e}"
)
last_size = len(gammas)
if breaking:
break
if itno < maxit - 1:
logger.info("\nConvergence reached !\n")
else:
warn("\nConvergence NOT reached !\n")
# undo normalization and compute final posterior mean
n_const = np.sqrt(M_normalize_constant) / G_normalize_constant
x_active = n_const * gammas[:, None] * A
return x_active, active_set


@verbose
def gamma_map(
evoked,
forward,
noise_cov,
alpha,
loose="auto",
depth=0.8,
xyz_same_gamma=True,
maxit=10000,
tol=1e-6,
update_mode=1,
gammas=None,
pca=True,
return_residual=False,
return_as_dipoles=False,
rank=None,
pick_ori=None,
verbose=None,
):
"""Hierarchical Bayes (Gamma-MAP) sparse source localization method.
Models each source time course using a zero-mean Gaussian prior with an
unknown variance (gamma) parameter. During estimation, most gammas are
driven to zero, resulting in a sparse source estimate, as in
:footcite:`WipfEtAl2007` and :footcite:`WipfNagarajan2009`.
For fixed-orientation forward operators, a separate gamma is used for each
source time course, while for free-orientation forward operators, the same
gamma is used for the three source time courses at each source space point
(separate gammas can be used in this case by using xyz_same_gamma=False).
Parameters
----------
evoked : instance of Evoked
Evoked data to invert.
    forward : instance of Forward
Forward operator.
noise_cov : instance of Covariance
Noise covariance to compute whitener.
alpha : float
Regularization parameter (noise variance).
%(loose)s
%(depth)s
xyz_same_gamma : bool
Use same gamma for xyz current components at each source space point.
Recommended for free-orientation forward solutions.
maxit : int
Maximum number of iterations.
tol : float
Tolerance parameter for convergence.
update_mode : int
Update mode, 1: MacKay update (default), 2: Modified MacKay update.
gammas : array, shape=(n_sources,)
Initial values for posterior variances (gammas). If None, a
variance of 1.0 is used.
pca : bool
        If True, the rank of the data is reduced to the true dimension.
return_residual : bool
If True, the residual is returned as an Evoked instance.
return_as_dipoles : bool
If True, the sources are returned as a list of Dipole instances.
%(rank_none)s

        .. versionadded:: 0.18
%(pick_ori)s
%(verbose)s

    Returns
-------
stc : instance of SourceEstimate
Source time courses.
residual : instance of Evoked
        The residual, i.e., the data not explained by the sources.
Only returned if return_residual is True.

    References
----------
.. footbibliography::
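
    Examples
    --------
    A minimal sketch, assuming ``evoked``, ``forward``, and ``noise_cov`` are
    already loaded (for example with :func:`mne.read_evokeds`,
    :func:`mne.read_forward_solution`, and :func:`mne.read_cov`); a suitable
    ``alpha`` depends on the data::

        from mne.inverse_sparse import gamma_map

        stc, residual = gamma_map(
            evoked, forward, noise_cov, alpha=0.2, return_residual=True
        )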
"""
_check_reference(evoked)
forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
forward, evoked.info, noise_cov, pca, depth, loose, rank
)
_check_ori(pick_ori, forward)
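    # with a free-orientation forward and xyz_same_gamma=True, one gamma is
    # shared by the three orientations at each source location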
group_size = 1 if (is_fixed_orient(forward) or not xyz_same_gamma) else 3
# get the data
sel = [evoked.ch_names.index(name) for name in gain_info["ch_names"]]
M = evoked.data[sel]
# whiten the data
logger.info("Whitening data matrix.")
M = np.dot(whitener, M)
# run the optimization
X, active_set = _gamma_map_opt(
M,
gain,
alpha,
maxit=maxit,
tol=tol,
update_mode=update_mode,
gammas=gammas,
group_size=group_size,
verbose=verbose,
)
if len(active_set) == 0:
raise Exception("No active dipoles found. alpha is too big.")
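    # estimate of the whitened data from the active sources (used below to
    # log the explained variance)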
M_estimate = gain[:, active_set] @ X
# Reapply weights to have correct unit
X = _reapply_source_weighting(X, source_weighting, active_set)
if return_residual:
residual = _compute_residual(forward, evoked, X, active_set, gain_info)
if group_size == 1 and not is_fixed_orient(forward):
# make sure each source has 3 components
idx, offset = divmod(active_set, 3)
active_src = np.unique(idx)
if len(X) < 3 * len(active_src):
X_xyz = np.zeros((len(active_src), 3, X.shape[1]), dtype=X.dtype)
idx = np.searchsorted(active_src, idx)
X_xyz[idx, offset, :] = X
X_xyz.shape = (len(active_src) * 3, X.shape[1])
X = X_xyz
active_set = (active_src[:, np.newaxis] * 3 + np.arange(3)).ravel()
    source_weighting[source_weighting == 0] = 1  # avoid division by zero
gain_active = gain[:, active_set] / source_weighting[active_set]
del source_weighting
tmin = evoked.times[0]
tstep = 1.0 / evoked.info["sfreq"]
if return_as_dipoles:
out = _make_dipoles_sparse(
X, active_set, forward, tmin, tstep, M, gain_active, active_is_idx=True
)
else:
out = _make_sparse_stc(
X,
active_set,
forward,
tmin,
tstep,
active_is_idx=True,
pick_ori=pick_ori,
verbose=verbose,
)
_log_exp_var(M, M_estimate, prefix="")
logger.info("[done]")
if return_residual:
out = out, residual
return out