# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
import pytest
from numpy.testing import assert_allclose
from scipy import linalg
import mne
from mne.beamformer import rap_music, trap_music
from mne.cov import regularize
from mne.datasets import testing
from mne.minimum_norm.tests.test_inverse import assert_var_exp_log
from mne.utils import catch_logging
data_path = testing.data_path(download=False)
# Every file used by these tests lives in the MEG/sample folder of the
# testing dataset; build all three paths from that common folder.
fname_ave, fname_cov, fname_fwd = (
    data_path / "MEG" / "sample" / base_name
    for base_name in (
        "sample_audvis-ave.fif",
        "sample_audvis_trunc-cov.fif",
        "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif",
    )
)
def _get_data(ch_decim=1):
    """Load the sample evoked data and a matching noise covariance.

    Parameters
    ----------
    ch_decim : int
        Keep only every ``ch_decim``-th of the selected MEG channels.

    Returns
    -------
    evoked : Evoked
        First condition (index 0) of the sample evoked file, baseline
        corrected with ``(None, 0)``, with "MEG 2443" marked bad, decimated
        by 12, cropped to 0-0.3 s and reduced to MEG channels.
    noise_cov : Covariance
        Noise covariance with its projectors removed, regularized with
        ``rank="full"`` and ``proj=False`` against ``evoked.info``.
    """
    evoked = mne.read_evokeds(fname_ave, 0, baseline=(None, 0))
    evoked.info["bads"] = ["MEG 2443"]
    # Fake low-pass value so that decimating by 12 is permitted.
    with evoked.info._unlock():
        evoked.info["lowpass"] = 16
    evoked.decimate(12)
    evoked.crop(0.0, 0.3)

    meg_idx = mne.pick_types(evoked.info, meg=True, eeg=False)[::ch_decim]
    kept_names = [evoked.ch_names[idx] for idx in meg_idx]
    evoked.pick(kept_names)
    evoked.info.normalize_proj()

    cov = mne.read_cov(fname_cov)
    cov["projs"] = []
    return evoked, regularize(cov, evoked.info, rank="full", proj=False)
def simu_data(evoked, forward, noise_cov, n_dipoles, times, nave=1):
    """Simulate evoked data produced by two Gaussian-shaped sources.

    One source vertex is drawn at random (fixed seed 42) in each of the
    first two source spaces of ``forward["src"]``: index 0 (left
    hemisphere), then index 1 (right hemisphere).

    Parameters
    ----------
    evoked : Evoked
        Template whose ``info`` is used for the simulation and whose
        sampling frequency sets the source time step.
    forward : Forward
        Forward solution used to project the sources to the sensors.
    noise_cov : Covariance
        Noise covariance passed to ``simulate_evoked``.
    n_dipoles : int
        Not used by this function; the number of sources is always 2.
    times : ndarray
        Time points at which the source waveforms are evaluated.
    nave : int
        Number of averages passed to ``simulate_evoked``.

    Returns
    -------
    sim_evoked : Evoked
        The simulated sensor data.
    stc : SourceEstimate
        The two simulated source time courses.
    """

    def _gaussian(amplitude, mu, sigma):
        # Gaussian probability-density bump centred at ``mu``, scaled by
        # ``amplitude``.
        return (
            amplitude
            / (sigma * np.sqrt(2 * np.pi))
            * np.exp(-((times - mu) ** 2) / (2 * sigma**2))
        )

    waveforms = np.array([_gaussian(1, 0.1, 0.005), _gaussian(-1, 0.075, 0.008)])
    data = waveforms * 1e-9

    src = forward["src"]
    rng = np.random.RandomState(42)
    vertices = []
    for hemi in (0, 1):  # draw order matters: the same rng feeds the noise
        hemi_vertno = src[hemi]["vertno"]
        vertices.append(hemi_vertno[[rng.randint(len(hemi_vertno))]])

    stc = mne.SourceEstimate(
        data, vertices=vertices, tmin=times.min(), tstep=1 / evoked.info["sfreq"]
    )
    sim_evoked = mne.simulation.simulate_evoked(
        forward, stc, evoked.info, noise_cov, nave=nave, random_state=rng
    )
    return sim_evoked, stc
def _check_dipoles(dipoles, fwd, stc, evoked, residual=None):
    """Check that the first two dipoles recover the two simulated sources.

    Parameters
    ----------
    dipoles : list of Dipole
        Fitted dipoles; ``dipoles[0]`` and ``dipoles[1]`` are checked at their
        first time point.
    fwd : Forward
        Forward solution whose ``src``, ``source_rr`` and ``source_nn`` give
        the true source positions and orientations.
    stc : SourceEstimate
        Simulated sources; ``stc.vertices[0]`` and ``stc.vertices[1]`` hold
        the source vertex in the first and second source space.
    evoked : Evoked
        The fitted data; only used for the residual check.
    residual : Evoked | None
        If given, its Frobenius norm must stay below 2 % of ``evoked``'s,
        separately for gradiometers and magnetometers.
    """
    src = fwd["src"]
    # Rows of the forward's source arrays. Sources of the second source space
    # follow those of the first, hence the offset.
    idx_first = np.where(src[0]["vertno"] == stc.vertices[0])[0]
    idx_second = np.where(src[1]["vertno"] == stc.vertices[1])[0] + len(
        src[0]["vertno"]
    )
    true_idx = np.concatenate([idx_first, idx_second])
    true_pos = fwd["source_rr"][true_idx]
    true_ori = fwd["source_nn"][true_idx]

    for dip in (dipoles[0], dipoles[1]):
        # Compare whole coordinate rows. ``pos in array`` would only test
        # whether *any* single coordinate matches, which is far too lax.
        row_match = np.all(np.isclose(true_pos, dip.pos[0]), axis=1)
        assert row_match.any(), f"dipole at {dip.pos[0]} is not a simulated source"
        # Orientation must be (anti-)parallel to one of the true normals.
        assert np.max(np.abs(true_ori @ dip.ori[0])) > 0.99
    # The two dipoles must recover two different sources.
    assert not np.allclose(dipoles[0].pos[0], dipoles[1].pos[0])

    if residual is None:
        return
    rel_tol = 0.02
    for meg_type in ("grad", "mag"):
        picks = mne.pick_types(residual.info, meg=meg_type)
        assert linalg.norm(residual.data[picks], ord="fro") < rel_tol * linalg.norm(
            evoked.data[picks], ord="fro"
        )
@testing.requires_testing_data
def test_rap_music_simulated():
    """Recover two simulated dipoles with RAP-MUSIC for several orientations."""
    n_dipoles = 2
    evoked, noise_cov = _get_data(ch_decim=16)
    fwd_free = mne.pick_channels_forward(
        mne.read_forward_solution(fname_fwd), evoked.ch_names
    )
    fwd_surf = mne.convert_forward_solution(fwd_free, surf_ori=True)
    fwd_fixed = mne.convert_forward_solution(
        fwd_free, force_fixed=True, surf_ori=True, use_cps=True
    )

    # Fixed orientation, noise scaled by the nave of the real evoked
    sim_evoked, stc = simu_data(
        evoked, fwd_fixed, noise_cov, n_dipoles, evoked.times, nave=evoked.nave
    )
    with catch_logging() as log:
        dipoles = rap_music(
            sim_evoked, fwd_fixed, noise_cov, n_dipoles=n_dipoles, verbose=True
        )
    assert_var_exp_log(log.getvalue(), 89, 91)
    _check_dipoles(dipoles, fwd_fixed, stc, sim_evoked)
    first_gof, second_gof = dipoles[0].gof, dipoles[1].gof
    assert 97 < first_gof.max() < 100
    assert 91 < second_gof.max() < 93
    assert first_gof.min() >= 0.0

    # A huge nave adds only a tiny amount of noise, so residuals must be small
    sim_evoked, stc = simu_data(
        evoked, fwd_fixed, noise_cov, n_dipoles, evoked.times, nave=100000
    )
    for fwd in (fwd_fixed, fwd_free, fwd_surf):
        dipoles, residual = rap_music(
            sim_evoked, fwd, noise_cov, n_dipoles=n_dipoles, return_residual=True
        )
        # Ground truth always comes from the fixed-orientation forward
        _check_dipoles(dipoles, fwd_fixed, stc, sim_evoked, residual)
@pytest.mark.slowtest
@testing.requires_testing_data
def test_rap_music_sphere():
    """Check RAP-MUSIC on real MEG data with a spherical head model."""
    evoked, noise_cov = _get_data(ch_decim=8)
    sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.04))
    vol_src = mne.setup_volume_source_space(
        subject=None,
        pos=10.0,
        sphere=(0.0, 0.0, 40, 65.0),
        mindist=5.0,
        exclude=0.0,
        sphere_units="mm",
    )
    fwd = mne.make_forward_solution(evoked.info, trans=None, src=vol_src, bem=sphere)
    with catch_logging() as log:
        dipoles = rap_music(evoked, fwd, noise_cov, n_dipoles=2, verbose=True)
    assert_var_exp_log(log.getvalue(), 47, 49)

    # Exactly one dipole on each side of x == 0, i.e. one per hemisphere
    positions = np.array([dip.pos[0] for dip in dipoles])
    assert positions.shape == (2, 3)
    x_coord = positions[:, 0]
    assert (x_coord < 0).sum() == 1
    assert (x_coord > 0).sum() == 1

    # Plausible amplitude scale
    assert 1e-10 < dipoles[0].amplitude[0] < 1e-7

    # Compare against a sequential single-dipole fit
    dip_fit = mne.fit_dipole(evoked, noise_cov, sphere)[0]
    for dip in dipoles:
        assert np.max(np.abs(np.dot(dip_fit.ori, dip.ori[0]))) > 0.99
    best = dip_fit.gof.argmax()
    dist = np.linalg.norm(dipoles[0].pos[best] - dip_fit.pos[best])
    assert 0.004 <= dist < 0.007
    assert_allclose(dipoles[0].gof[best], dip_fit.gof[best], atol=3)
@testing.requires_testing_data
def test_rap_music_picks():
    """Run RAP-MUSIC on an evoked restricted to its MEG channels."""
    n_dipoles = 2
    forward = mne.read_forward_solution(fname_fwd)
    noise_cov = mne.read_cov(fname_cov)
    evoked = mne.read_evokeds(fname_ave, condition="Right Auditory", baseline=(None, 0))
    evoked.crop(tmin=0.05, tmax=0.15)  # N100 window
    evoked.pick(picks="meg")
    found = rap_music(evoked, forward, noise_cov, n_dipoles=n_dipoles)
    assert len(found) == n_dipoles
@testing.requires_testing_data
def test_trap_music():
    """Run TRAP-MUSIC on the MEG channels of the N100 response."""
    n_dipoles = 2
    forward = mne.read_forward_solution(fname_fwd)
    noise_cov = mne.read_cov(fname_cov)
    evoked = mne.read_evokeds(fname_ave, condition="Right Auditory", baseline=(None, 0))
    evoked.crop(tmin=0.05, tmax=0.15)  # N100 window
    evoked.pick(picks="meg")
    found = trap_music(evoked, forward, noise_cov, n_dipoles=n_dipoles)
    assert len(found) == n_dipoles