# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
import pytest
from numpy.testing import (
assert_allclose,
assert_array_almost_equal,
assert_array_equal,
assert_array_less,
)
import mne
from mne import convert_forward_solution, read_cov, read_evokeds, read_forward_solution
from mne.datasets import testing
from mne.dipole import Dipole
from mne.inverse_sparse import mixed_norm, tf_mixed_norm
from mne.inverse_sparse.mxne_inverse import (
_compute_mxne_sure,
_split_gof,
make_stc_from_dipoles,
)
from mne.inverse_sparse.mxne_optim import norm_l2inf
from mne.label import read_label
from mne.minimum_norm import apply_inverse, make_inverse_operator
from mne.minimum_norm.tests.test_inverse import assert_stc_res, assert_var_exp_log
from mne.simulation import simulate_evoked, simulate_sparse_stc
from mne.source_estimate import VolSourceEstimate
from mne.utils import _record_warnings, assert_stcs_equal, catch_logging
data_path = testing.data_path(download=False)
# NOTE: These use the ave and cov from sample dataset (no _trunc)
fname_data = data_path / "MEG" / "sample" / "sample_audvis-ave.fif"
fname_cov = data_path / "MEG" / "sample" / "sample_audvis-cov.fif"
fname_raw = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif"
fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-6-fwd.fif"
label = "Aud-rh"
fname_label = data_path / "MEG" / "sample" / "labels" / f"{label}.label"
@pytest.fixture(scope="module", params=[testing._pytest_param])
def forward():
"""Get a forward solution."""
# module scope it for speed (but don't overwrite in use!)
return read_forward_solution(fname_fwd)
@testing.requires_testing_data
@pytest.mark.timeout(150) # ~30 s on Travis Linux
@pytest.mark.slowtest
def test_mxne_inverse_standard(forward):
"""Test (TF-)MxNE inverse computation."""
# Read noise covariance matrix
cov = read_cov(fname_cov)
# Handling average file
loose = 0.0
depth = 0.9
evoked = read_evokeds(fname_data, condition=0, baseline=(None, 0))
evoked.crop(tmin=-0.05, tmax=0.2)
evoked_l21 = evoked.copy()
evoked_l21.crop(tmin=0.081, tmax=0.1)
label = read_label(fname_label)
assert label.hemi == "rh"
forward = convert_forward_solution(forward, surf_ori=True)
# Reduce source space to make test computation faster
inverse_operator = make_inverse_operator(
evoked_l21.info,
forward,
cov,
loose=loose,
depth=depth,
fixed=True,
use_cps=True,
)
stc_dspm = apply_inverse(
evoked_l21, inverse_operator, lambda2=1.0 / 9.0, method="dSPM"
)
stc_dspm.data[np.abs(stc_dspm.data) < 12] = 0.0
stc_dspm.data[np.abs(stc_dspm.data) >= 12] = 1.0
weights_min = 0.5
# MxNE tests
alpha = 70 # spatial regularization parameter
with _record_warnings(): # CD
stc_cd = mixed_norm(
evoked_l21,
forward,
cov,
alpha,
loose=loose,
depth=depth,
maxit=300,
tol=1e-8,
active_set_size=10,
weights=stc_dspm,
weights_min=weights_min,
solver="cd",
)
stc_bcd = mixed_norm(
evoked_l21,
forward,
cov,
alpha,
loose=loose,
depth=depth,
maxit=300,
tol=1e-8,
active_set_size=10,
weights=stc_dspm,
weights_min=weights_min,
solver="bcd",
)
assert_array_almost_equal(stc_cd.times, evoked_l21.times, 5)
assert_array_almost_equal(stc_bcd.times, evoked_l21.times, 5)
assert_allclose(stc_cd.data, stc_bcd.data, rtol=1e-3, atol=0.0)
assert stc_cd.vertices[1][0] in label.vertices
assert stc_bcd.vertices[1][0] in label.vertices
# vector
with _record_warnings(): # no convergence
stc = mixed_norm(evoked_l21, forward, cov, alpha, loose=1, maxit=2)
with _record_warnings(): # no convergence
stc_vec = mixed_norm(
evoked_l21, forward, cov, alpha, loose=1, maxit=2, pick_ori="vector"
)
assert_stcs_equal(stc_vec.magnitude(), stc)
with _record_warnings(), pytest.raises(ValueError, match="pick_ori="):
mixed_norm(evoked_l21, forward, cov, alpha, loose=0, maxit=2, pick_ori="vector")
with _record_warnings(), catch_logging() as log: # CD
dips = mixed_norm(
evoked_l21,
forward,
cov,
alpha,
loose=loose,
depth=depth,
maxit=300,
tol=1e-8,
active_set_size=10,
weights=stc_dspm,
weights_min=weights_min,
solver="cd",
return_as_dipoles=True,
verbose=True,
)
stc_dip = make_stc_from_dipoles(dips, forward["src"])
assert isinstance(dips[0], Dipole)
assert stc_dip.subject == "sample"
assert_stcs_equal(stc_cd, stc_dip)
assert_var_exp_log(log.getvalue(), 51, 53) # 51.8
# Single time point things should match
with _record_warnings(), catch_logging() as log:
dips = mixed_norm(
evoked_l21.copy().crop(0.081, 0.081),
forward,
cov,
alpha,
loose=loose,
depth=depth,
maxit=300,
tol=1e-8,
active_set_size=10,
weights=stc_dspm,
weights_min=weights_min,
solver="cd",
return_as_dipoles=True,
verbose=True,
)
assert_var_exp_log(log.getvalue(), 37.8, 38.0) # 37.9
gof = sum(dip.gof[0] for dip in dips) # these are now partial exp vars
assert_allclose(gof, 37.9, atol=0.1)
with _record_warnings(), catch_logging() as log:
stc, res = mixed_norm(
evoked_l21,
forward,
cov,
alpha,
loose=loose,
depth=depth,
maxit=300,
tol=1e-8,
weights=stc_dspm, # gh-6382
active_set_size=10,
return_residual=True,
solver="cd",
verbose=True,
)
assert_array_almost_equal(stc.times, evoked_l21.times, 5)
assert stc.vertices[1][0] in label.vertices
assert_var_exp_log(log.getvalue(), 51, 53) # 51.8
assert stc.data.min() < -1e-9 # signed
assert_stc_res(evoked_l21, stc, forward, res)
# irMxNE tests
with _record_warnings(), catch_logging() as log: # CD
stc, residual = mixed_norm(
evoked_l21,
forward,
cov,
alpha,
n_mxne_iter=5,
loose=0.0001,
depth=depth,
maxit=300,
tol=1e-8,
active_set_size=10,
solver="cd",
return_residual=True,
pick_ori="vector",
verbose=True,
)
assert_array_almost_equal(stc.times, evoked_l21.times, 5)
assert stc.vertices[1][0] in label.vertices
assert stc.vertices == [[63152], [79017]]
assert_var_exp_log(log.getvalue(), 51, 53) # 51.8
assert_stc_res(evoked_l21, stc, forward, residual)
# Do with TF-MxNE for test memory savings
alpha = 60.0 # overall regularization parameter
l1_ratio = 0.01 # temporal regularization proportion
stc, _ = tf_mixed_norm(
evoked,
forward,
cov,
loose=loose,
depth=depth,
maxit=100,
tol=1e-4,
tstep=4,
wsize=16,
window=0.1,
weights=stc_dspm,
weights_min=weights_min,
return_residual=True,
alpha=alpha,
l1_ratio=l1_ratio,
)
assert_array_almost_equal(stc.times, evoked.times, 5)
assert stc.vertices[1][0] in label.vertices
# vector
stc_nrm = tf_mixed_norm(
evoked,
forward,
cov,
loose=1,
depth=depth,
maxit=2,
tol=1e-4,
tstep=4,
wsize=16,
window=0.1,
weights=stc_dspm,
weights_min=weights_min,
alpha=alpha,
l1_ratio=l1_ratio,
)
stc_vec, residual = tf_mixed_norm(
evoked,
forward,
cov,
loose=1,
depth=depth,
maxit=2,
tol=1e-4,
tstep=4,
wsize=16,
window=0.1,
weights=stc_dspm,
weights_min=weights_min,
alpha=alpha,
l1_ratio=l1_ratio,
pick_ori="vector",
return_residual=True,
)
assert_stcs_equal(stc_vec.magnitude(), stc_nrm)
pytest.raises(
ValueError, tf_mixed_norm, evoked, forward, cov, alpha=101, l1_ratio=0.03
)
pytest.raises(
ValueError, tf_mixed_norm, evoked, forward, cov, alpha=50.0, l1_ratio=1.01
)
@pytest.mark.slowtest
@testing.requires_testing_data
def test_mxne_vol_sphere():
"""Test (TF-)MxNE with a sphere forward and volumic source space."""
evoked = read_evokeds(fname_data, condition=0, baseline=(None, 0))
evoked.crop(tmin=-0.05, tmax=0.2)
cov = read_cov(fname_cov)
evoked_l21 = evoked.copy()
evoked_l21.crop(tmin=0.081, tmax=0.1)
info = evoked.info
sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.0), head_radius=0.080)
src = mne.setup_volume_source_space(
subject=None,
pos=15.0,
mri=None,
sphere=(0.0, 0.0, 0.0, 0.08),
bem=None,
mindist=5.0,
exclude=2.0,
sphere_units="m",
)
fwd = mne.make_forward_solution(
info, trans=None, src=src, bem=sphere, eeg=False, meg=True
)
alpha = 80.0
# Computing inverse with restricted orientations should also work, since
# we have a discrete source space.
stc = mixed_norm(
evoked_l21,
fwd,
cov,
alpha,
loose=0.2,
return_residual=False,
maxit=3,
tol=1e-8,
active_set_size=10,
)
assert_array_almost_equal(stc.times, evoked_l21.times, 5)
# irMxNE tests
with catch_logging() as log:
stc = mixed_norm(
evoked_l21,
fwd,
cov,
alpha,
n_mxne_iter=1,
maxit=30,
tol=1e-8,
active_set_size=10,
verbose=True,
)
assert isinstance(stc, VolSourceEstimate)
assert_array_almost_equal(stc.times, evoked_l21.times, 5)
assert_var_exp_log(log.getvalue(), 9, 11) # 10.2
# Compare orientation obtained using fit_dipole and gamma_map
# for a simulated evoked containing a single dipole
stc = mne.VolSourceEstimate(
50e-9 * np.random.RandomState(42).randn(1, 4),
vertices=[stc.vertices[0][:1]],
tmin=stc.tmin,
tstep=stc.tstep,
)
evoked_dip = mne.simulation.simulate_evoked(
fwd, stc, info, cov, nave=1e9, use_cps=True
)
dip_mxne = mixed_norm(
evoked_dip,
fwd,
cov,
alpha=80,
n_mxne_iter=1,
maxit=30,
tol=1e-8,
active_set_size=10,
return_as_dipoles=True,
)
amp_max = [np.max(d.amplitude) for d in dip_mxne]
dip_mxne = dip_mxne[np.argmax(amp_max)]
assert dip_mxne.pos[0] in src[0]["rr"][stc.vertices[0]]
dip_fit = mne.fit_dipole(evoked_dip, cov, sphere)[0]
assert np.abs(np.dot(dip_fit.ori[0], dip_mxne.ori[0])) > 0.99
dist = 1000 * np.linalg.norm(dip_fit.pos[0] - dip_mxne.pos[0])
assert dist < 4.0 # within 4 mm
# Do with TF-MxNE for test memory savings
alpha = 60.0 # overall regularization parameter
l1_ratio = 0.01 # temporal regularization proportion
stc, _ = tf_mixed_norm(
evoked,
fwd,
cov,
maxit=3,
tol=1e-4,
tstep=16,
wsize=32,
window=0.1,
alpha=alpha,
l1_ratio=l1_ratio,
return_residual=True,
)
assert isinstance(stc, VolSourceEstimate)
assert_array_almost_equal(stc.times, evoked.times, 5)
@pytest.mark.parametrize("mod", (None, "mult", "augment", "sign", "zero", "less"))
def test_split_gof_basic(mod):
"""Test splitting the goodness of fit."""
# first a trivial case
gain = np.array([[0.0, 1.0, 1.0], [1.0, 1.0, 0.0]]).T
M = np.ones((3, 1))
X = np.ones((2, 1))
M_est = gain @ X
assert_allclose(M_est, np.array([[1.0, 2.0, 1.0]]).T) # a reasonable estimate
if mod == "mult":
gain *= [1.0, -0.5]
X[1] *= -2
elif mod == "augment":
gain = np.concatenate((gain, np.zeros((3, 1))), axis=1)
X = np.concatenate((X, [[1.0]]))
elif mod == "sign":
gain[1] *= -1
M[1] *= -1
M_est[1] *= -1
elif mod in ("zero", "less"):
gain = np.array([[1, 1.0, 1.0], [1.0, 1.0, 1.0]]).T
if mod == "zero":
X[:, 0] = [1.0, 0.0]
else:
X[:, 0] = [1.0, 0.5]
M_est = gain @ X
else:
assert mod is None
res = M - M_est
gof = 100 * (1.0 - (res * res).sum() / (M * M).sum())
gof_split = _split_gof(M, X, gain)
assert_allclose(gof_split.sum(), gof)
want = gof_split[[0, 0]]
if mod == "augment":
want = np.concatenate((want, [[0]]))
if mod in ("mult", "less"):
assert_array_less(gof_split[1], gof_split[0])
elif mod == "zero":
assert_allclose(gof_split[0], gof_split.sum(0))
assert_allclose(gof_split[1], 0.0, atol=1e-6)
else:
assert_allclose(gof_split, want, atol=1e-12)
@testing.requires_testing_data
@pytest.mark.parametrize(
"idx, weights",
[
# empirically determined approximately orthogonal columns: 0, 15157, 19448
([0], [1]),
([0, 15157], [1, 1]),
([0, 15157], [1, 3]),
([0, 15157], [5, -1]),
([0, 15157, 19448], [1, 1, 1]),
([0, 15157, 19448], [1e-2, 1, 5]),
],
)
def test_split_gof_meg(forward, idx, weights):
"""Test GOF splitting on MEG data."""
gain = forward["sol"]["data"][:, idx]
# close to orthogonal
norms = np.linalg.norm(gain, axis=0)
triu = np.triu_indices(len(idx), 1)
prods = np.abs(np.dot(gain.T, gain) / np.outer(norms, norms))[triu]
assert_array_less(prods, 5e-3) # approximately orthogonal
# first, split across time (one dipole per time point)
M = gain * weights
gof_split = _split_gof(M, np.diag(weights), gain)
assert_allclose(gof_split.sum(0), 100.0, atol=1e-5) # all sum to 100
assert_allclose(gof_split, 100 * np.eye(len(weights)), atol=1) # loc
# next, summed to a single time point (all dipoles active at one time pt)
weights = np.array(weights)[:, np.newaxis]
x = gain @ weights
assert x.shape == (gain.shape[0], 1)
gof_split = _split_gof(x, weights, gain)
want = (norms * weights.T).T ** 2
want = 100 * want / want.sum()
assert_allclose(gof_split, want, atol=1e-3, rtol=1e-2)
assert_allclose(gof_split.sum(), 100, rtol=1e-5)
@pytest.mark.parametrize(
"n_sensors, n_dipoles, n_times",
[
(10, 15, 7),
(20, 60, 20),
],
)
@pytest.mark.parametrize("nnz", [2, 4])
@pytest.mark.parametrize("corr", [0.75])
@pytest.mark.parametrize("n_orient", [1, 3])
def test_mxne_inverse_sure_synthetic(
n_sensors, n_dipoles, n_times, nnz, corr, n_orient, snr=4
):
"""Tests SURE criterion for automatic alpha selection on synthetic data."""
rng = np.random.RandomState(0)
sigma = np.sqrt(1 - corr**2)
U = rng.randn(n_sensors)
# generate gain matrix
G = np.empty([n_sensors, n_dipoles], order="F")
G[:, :n_orient] = np.expand_dims(U, axis=-1)
n_dip_per_pos = n_dipoles // n_orient
for j in range(1, n_dip_per_pos):
U *= corr
U += sigma * rng.randn(n_sensors)
G[:, j * n_orient : (j + 1) * n_orient] = np.expand_dims(U, axis=-1)
# generate coefficient matrix
support = rng.choice(n_dip_per_pos, nnz, replace=False)
X = np.zeros((n_dipoles, n_times))
for k in support:
X[k * n_orient : (k + 1) * n_orient, :] = rng.normal(size=(n_orient, n_times))
# generate measurement matrix
M = G @ X
noise = rng.randn(n_sensors, n_times)
sigma = 1 / np.linalg.norm(noise) * np.linalg.norm(M) / snr
M += sigma * noise
# inverse modeling with sure
alpha_max = norm_l2inf(np.dot(G.T, M), n_orient, copy=False)
alpha_grid = np.geomspace(alpha_max, alpha_max / 10, num=15)
_, active_set, _ = _compute_mxne_sure(
M,
G,
alpha_grid,
sigma=sigma,
n_mxne_iter=5,
maxit=3000,
tol=1e-4,
n_orient=n_orient,
active_set_size=10,
debias=True,
solver="auto",
dgap_freq=10,
random_state=0,
verbose=False,
)
assert np.count_nonzero(active_set, axis=-1) == n_orient * nnz
@pytest.mark.slowtest # slow on Azure
@testing.requires_testing_data
def test_mxne_inverse_sure():
"""Tests SURE criterion for automatic alpha selection on MEG data."""
def data_fun(times):
data = np.zeros(times.shape)
data[times >= 0] = 50e-9
return data
n_dipoles = 2
raw = mne.io.read_raw_fif(fname_raw)
info = mne.io.read_info(fname_data)
with info._unlock():
info["projs"] = []
noise_cov = mne.make_ad_hoc_cov(info)
label_names = ["Aud-lh", "Aud-rh"]
labels = [
mne.read_label(data_path / "MEG" / "sample" / "labels" / f"{ln}.label")
for ln in label_names
]
fname_fwd = (
data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif"
)
forward = mne.read_forward_solution(fname_fwd)
forward = mne.pick_types_forward(
forward, meg="grad", eeg=False, exclude=raw.info["bads"]
)
times = np.arange(100, dtype=np.float64) / raw.info["sfreq"] - 0.1
stc = simulate_sparse_stc(
forward["src"],
n_dipoles=n_dipoles,
times=times,
random_state=1,
labels=labels,
data_fun=data_fun,
)
nave = 30
evoked = simulate_evoked(
forward, stc, info, noise_cov, nave=nave, use_cps=False, iir_filter=None
)
evoked = evoked.crop(tmin=0, tmax=10e-3)
stc_ = mixed_norm(evoked, forward, noise_cov, loose=0.9, n_mxne_iter=5, depth=0.9)
assert_array_equal(stc_.vertices, stc.vertices)
@pytest.mark.slowtest # slow on Azure
@testing.requires_testing_data
def test_mxne_inverse_empty():
"""Tests solver with too high alpha."""
evoked = read_evokeds(fname_data, condition=0, baseline=(None, 0))
evoked.pick("grad", exclude="bads")
fname_fwd = (
data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif"
)
forward = mne.read_forward_solution(fname_fwd)
forward = mne.pick_types_forward(
forward, meg="grad", eeg=False, exclude=evoked.info["bads"]
)
cov = read_cov(fname_cov)
with pytest.warns(RuntimeWarning, match="too big"):
stc, residual = mixed_norm(
evoked, forward, cov, n_mxne_iter=3, alpha=99, return_residual=True
)
assert stc.data.size == 0
assert stc.vertices[0].size == 0
assert stc.vertices[1].size == 0
assert_allclose(evoked.data, residual.data)