# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import gc
from pathlib import Path
import numpy as np
import pytest
from numpy.testing import (
assert_allclose,
assert_array_almost_equal,
assert_array_equal,
assert_equal,
)
from mne import (
SourceEstimate,
VectorSourceEstimate,
apply_forward,
apply_forward_raw,
average_forward_solutions,
convert_forward_solution,
pick_types_forward,
read_evokeds,
read_forward_solution,
write_forward_solution,
)
from mne._fiff.pick import pick_channels_forward
from mne.channels import equalize_channels
from mne.datasets import testing
from mne.forward import (
Forward,
compute_depth_prior,
compute_orient_prior,
is_fixed_orient,
restrict_forward_to_label,
restrict_forward_to_stc,
)
from mne.io import read_info
from mne.label import read_label
from mne.utils import _record_warnings, requires_mne, run_subprocess
# Location of the MNE testing dataset. ``download=False`` is passed so this
# module-level call presumably only resolves the path without fetching data
# (the tests below are guarded by ``testing.requires_testing_data``).
data_path = testing.data_path(download=False)
# Forward solutions stored under the "sample" subject of the testing dataset.
fname_meeg = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif"
fname_meeg_grad = (
    data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-2-grad-fwd.fif"
)
# Evoked file resolved relative to this test module: two levels above the
# directory containing this file, then ``io/tests/data``.
fname_evoked = Path(__file__).parents[2] / "io" / "tests" / "data" / "test-ave.fif"
# Directory from which the ``*.label`` files used by the label tests are read.
label_path = data_path / "MEG" / "sample" / "labels"
def assert_forward_allclose(f1, f2, rtol=1e-7):
    """Compare two potentially converted forward solutions.

    Parameters
    ----------
    f1, f2 : dict-like
        Forward solutions to compare. Each must provide the keys ``"sol"``,
        ``"sol_grad"``, ``"source_nn"``, ``"source_ori"``, ``"surf_ori"`` and
        ``"src"``.
    rtol : float
        Relative tolerance used when comparing the gain matrices and the
        source normals.

    Raises
    ------
    AssertionError
        If any of the compared fields differ.
    """
    # Gain matrix and its bookkeeping
    assert_allclose(f1["sol"]["data"], f2["sol"]["data"], rtol=rtol)
    assert f1["sol"]["ncol"] == f2["sol"]["ncol"]
    assert f1["sol"]["ncol"] == f1["sol"]["data"].shape[1]
    assert_allclose(f1["source_nn"], f2["source_nn"], rtol=rtol)
    # The gradient solution is optional, but both inputs must agree on it
    if f1["sol_grad"] is not None:
        assert f2["sol_grad"] is not None
        assert_allclose(f1["sol_grad"]["data"], f2["sol_grad"]["data"])
        assert f1["sol_grad"]["ncol"] == f2["sol_grad"]["ncol"]
        assert f1["sol_grad"]["ncol"] == f1["sol_grad"]["data"].shape[1]
    else:
        assert f2["sol_grad"] is None
    assert f1["source_ori"] == f2["source_ori"]
    assert f1["surf_ori"] == f2["surf_ori"]
    # Compare the source-space coordinate frame *between* the two solutions
    # (comparing f1 against itself would be a tautology).
    assert f1["src"][0]["coord_frame"] == f2["src"][0]["coord_frame"]
@testing.requires_testing_data
def test_convert_forward():
    """Round-trip a forward solution through its orientation representations.

    The solution is converted to surface orientation and back, then to fixed
    orientation and back, and each round trip is compared to the original.
    """
    original = read_forward_solution(fname_meeg_grad)
    description = repr(original)
    for expected_token in ("306", "60"):
        assert expected_token in description
    assert description
    assert isinstance(original, Forward)

    # surface-oriented representation ...
    surf_oriented = convert_forward_solution(original, surf_ori=True)
    # ... and back again
    restored = convert_forward_solution(surf_oriented, surf_ori=False)
    assert repr(restored)
    assert isinstance(restored, Forward)
    assert_forward_allclose(original, restored)
    # release the intermediate before creating the next conversion
    del restored
    gc.collect()

    # fixed-orientation representation
    fixed = convert_forward_solution(
        surf_oriented, surf_ori=True, force_fixed=True, use_cps=False
    )
    del surf_oriented
    gc.collect()
    assert repr(fixed)
    assert isinstance(fixed, Forward)
    assert is_fixed_orient(fixed)

    # back to cartesian, i.e. the condition the file was read in
    restored = convert_forward_solution(fixed, surf_ori=False, force_fixed=False)
    assert repr(restored)
    assert isinstance(restored, Forward)
    assert_forward_allclose(original, restored)
    del original, restored, fixed
    gc.collect()
def _assert_forward_contents(fwd, n_channels, n_cols):
    """Check the gain-matrix shape and channel/transform fields of ``fwd``."""
    assert_equal(fwd["sol"]["data"].shape, (n_channels, n_cols))
    assert_equal(len(fwd["sol"]["row_names"]), n_channels)
    assert_equal(len(fwd["info"]["chs"]), n_channels)
    assert "dev_head_t" in fwd["info"]
    assert "mri_head_t" in fwd


def _assert_fixed_roundtrip(fname, fwd, use_cps):
    """Write fixed-orientation ``fwd`` to ``fname``, read it back and compare."""
    with pytest.warns(RuntimeWarning, match="stored on disk"):
        write_forward_solution(fname, fwd, overwrite=True)
    fwd_read = read_forward_solution(fname)
    fwd_read = convert_forward_solution(
        fwd_read, surf_ori=True, force_fixed=True, use_cps=use_cps
    )
    assert repr(fwd_read)
    assert isinstance(fwd_read, Forward)
    assert is_fixed_orient(fwd_read)
    assert_forward_allclose(fwd, fwd_read)


@pytest.mark.slowtest
@testing.requires_testing_data
def test_io_forward(tmp_path):
    """Test IO for forward solutions.

    Parameters
    ----------
    tmp_path : pathlib.Path
        Temporary directory provided by pytest; all files are written there.
    """
    # do extensive tests with MEEG + grad
    n_channels, n_src = 366, 108
    fwd = read_forward_solution(fname_meeg_grad)
    assert isinstance(fwd, Forward)
    fwd = convert_forward_solution(fwd, surf_ori=True)
    assert_equal(fwd["sol"]["data"].shape, (n_channels, n_src))
    assert_equal(len(fwd["sol"]["row_names"]), n_channels)
    fname_temp = tmp_path / "test-fwd.fif"
    with pytest.warns(RuntimeWarning, match="stored on disk"):
        write_forward_solution(fname_temp, fwd, overwrite=True)

    # compare what was written against a freshly read + converted reference
    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True)
    fwd_read = read_forward_solution(fname_temp)
    fwd_read = convert_forward_solution(fwd_read, surf_ori=True)
    _assert_forward_contents(fwd_read, n_channels, n_src)
    assert_array_almost_equal(fwd["sol"]["data"], fwd_read["sol"]["data"])

    # fixed orientation without cortical patch statistics
    fwd = read_forward_solution(fname_meeg)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=False)
    _assert_fixed_roundtrip(fname_temp, fwd, use_cps=False)

    # fixed orientation with cortical patch statistics; integer division keeps
    # the expected shape made of ints rather than floats
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True)
    _assert_forward_contents(fwd, n_channels, 1494 // 3)
    assert fwd["surf_ori"]
    _assert_fixed_roundtrip(fname_temp, fwd, use_cps=True)

    # same again for the solution that also carries gradients
    fwd = read_forward_solution(fname_meeg_grad)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True)
    _assert_forward_contents(fwd, n_channels, n_src // 3)
    assert fwd["surf_ori"]
    _assert_fixed_roundtrip(fname_temp, fwd, use_cps=True)

    # test warnings on bad filenames
    fwd = read_forward_solution(fname_meeg_grad)
    fwd_badname = tmp_path / "test-bad-name.fif.gz"
    with pytest.warns(RuntimeWarning, match="end with"):
        write_forward_solution(fwd_badname, fwd)
    with pytest.warns(RuntimeWarning, match="end with"):
        read_forward_solution(fwd_badname)

    # an unconverted solution round-trips without the "stored on disk" warning
    fwd = read_forward_solution(fname_meeg)
    write_forward_solution(fname_temp, fwd, overwrite=True)
    fwd_read = read_forward_solution(fname_temp)
    assert_forward_allclose(fwd, fwd_read)

    # HDF5 round trip (skipped when the optional dependencies are missing)
    h5py = pytest.importorskip("h5py")
    pytest.importorskip("h5io")
    fname_h5 = fname_temp.with_suffix(".h5")
    fwd.save(fname_h5)
    with h5py.File(fname_h5, "r"):
        pass  # just checks for hdf5-ness
    fwd_read = read_forward_solution(fname_h5)
    assert_forward_allclose(fwd, fwd_read)
@testing.requires_testing_data
def test_apply_forward():
    """Check projection of source-space data to sensor space.

    An all-ones source estimate is projected through a fixed-orientation MEG
    forward solution as Evoked, as a vector source estimate, and as Raw.
    """
    first, last = 0, 5
    n_samp = last - first - 1
    srate = 10.0
    t0 = 0.123
    t_end = t0 + (n_samp - 1) / srate

    fwd = read_forward_solution(fname_meeg)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True)
    fwd = pick_types_forward(fwd, meg=True)
    assert isinstance(fwd, Forward)

    vertices = [fwd["src"][hemi]["vertno"] for hemi in (0, 1)]
    n_dipoles = sum(len(v) for v in vertices)
    stc = SourceEstimate(
        np.ones((n_dipoles, n_samp)), vertices, tmin=t0, tstep=1.0 / srate
    )
    # with unit sources every sample equals the row sum of the gain matrix
    expected_row_sum = n_samp * np.sum(fwd["sol"]["data"], axis=1)

    # Evoked
    evoked = read_evokeds(fname_evoked, condition=0)
    evoked.pick(picks="meg")
    with _record_warnings(), pytest.warns(
        RuntimeWarning, match="only .* positive values"
    ):
        evoked = apply_forward(fwd, stc, evoked.info, start=first, stop=last)
    evoked_times = evoked.times
    assert_array_almost_equal(evoked.info["sfreq"], srate)
    assert_array_almost_equal(np.sum(evoked.data, axis=1), expected_row_sum)
    assert_array_almost_equal(evoked_times[0], t0)
    assert_array_almost_equal(evoked_times[-1], t_end)

    # vector
    vector_data = fwd["source_nn"][:, :, np.newaxis] * stc.data[:, np.newaxis]
    stc_vec = VectorSourceEstimate(vector_data, stc.vertices, stc.tmin, stc.tstep)
    with pytest.warns(RuntimeWarning, match="very large"):
        evoked_vec = apply_forward(fwd, stc_vec, evoked.info)
    assert np.abs(evoked_vec.data).mean() > 1e-5
    assert_allclose(evoked.data, evoked_vec.data, atol=1e-10)

    # Raw
    with pytest.warns(RuntimeWarning, match="very large"), pytest.warns(
        RuntimeWarning, match="only .* positive values"
    ):
        raw_proj = apply_forward_raw(fwd, stc, evoked.info, start=first, stop=last)
    raw_data, _ = raw_proj[:, :]
    assert_array_almost_equal(raw_proj.info["sfreq"], srate)
    assert_array_almost_equal(np.sum(raw_data, axis=1), expected_row_sum)
    # sample indices only resolve time to within one sample period
    one_sample = 1.0 / srate
    assert_allclose(raw_proj.first_samp / srate, t0, atol=one_sample)
    assert_allclose(raw_proj.last_samp / srate, t_end, atol=one_sample)
@testing.requires_testing_data
def test_restrict_forward_to_stc(tmp_path):
    """Check restricting a forward solution to the vertices of a SourceEstimate.

    Runs once with fixed and once with free orientation, then verifies that the
    free-orientation result survives a write/read round trip.
    """
    start, stop = 0, 5
    n_samp = stop - start - 1
    srate = 10.0
    t0 = 0.123
    n_lh, n_rh = 15, 5

    def _restrict(**convert_kwargs):
        """Return the MEG-only forward and its restriction to a small stc."""
        full = read_forward_solution(fname_meeg)
        full = convert_forward_solution(full, **convert_kwargs)
        full = pick_types_forward(full, meg=True)
        verts = [full["src"][0]["vertno"][0:n_lh], full["src"][1]["vertno"][0:n_rh]]
        data = np.ones((len(verts[0]) + len(verts[1]), n_samp))
        stc = SourceEstimate(data, verts, tmin=t0, tstep=1.0 / srate)
        return full, restrict_forward_to_stc(full, stc)

    def _check(full, restricted, n_cols):
        """Assert column count and per-hemisphere vertex bookkeeping."""
        assert_equal(restricted["sol"]["ncol"], n_cols)
        assert_equal(restricted["src"][0]["nuse"], 15)
        assert_equal(restricted["src"][1]["nuse"], 5)
        assert_equal(restricted["src"][0]["vertno"], full["src"][0]["vertno"][0:15])
        assert_equal(restricted["src"][1]["vertno"], full["src"][1]["vertno"][0:5])

    # fixed orientation: one column per retained source
    fwd, fwd_out = _restrict(surf_ori=True, force_fixed=True, use_cps=True)
    assert isinstance(fwd_out, Forward)
    _check(fwd, fwd_out, 20)

    # free orientation: three columns per retained source
    fwd, fwd_out = _restrict(surf_ori=True, force_fixed=False)
    _check(fwd, fwd_out, 60)

    # Saving the restricted forward object only works if all of its fields
    # were updated consistently.
    fname_copy = tmp_path / "copy-fwd.fif"
    with pytest.warns(RuntimeWarning, match="stored on disk"):
        write_forward_solution(fname_copy, fwd_out, overwrite=True)
    reloaded = read_forward_solution(fname_copy)
    reloaded = convert_forward_solution(reloaded, surf_ori=True, force_fixed=False)
    assert_forward_allclose(fwd_out, reloaded)
def _expected_label_vertices(fwd, labels):
    """Return, per hemisphere, the source vertices of ``fwd`` inside ``labels``."""
    expected = []
    for hemi, label in enumerate(labels):
        vertno = fwd["src"][hemi]["vertno"]
        sel = np.searchsorted(vertno, np.intersect1d(vertno, label.vertices))
        expected.append(vertno[sel])
    return expected


def _assert_restricted_to(fwd_out, expected, n_orient):
    """Check column count and source-space bookkeeping of a restricted forward."""
    vertno_lh, vertno_rh = expected
    assert_equal(fwd_out["sol"]["ncol"], n_orient * (len(vertno_lh) + len(vertno_rh)))
    assert_equal(fwd_out["src"][0]["nuse"], len(vertno_lh))
    assert_equal(fwd_out["src"][1]["nuse"], len(vertno_rh))
    assert_equal(fwd_out["src"][0]["vertno"], vertno_lh)
    assert_equal(fwd_out["src"][1]["vertno"], vertno_rh)


@testing.requires_testing_data
def test_restrict_forward_to_label(tmp_path):
    """Test restriction of source space to label.

    Parameters
    ----------
    tmp_path : pathlib.Path
        Temporary directory provided by pytest; the restricted solution is
        written there for the round-trip check.
    """
    # labels[0] is applied to src[0] and labels[1] to src[1]
    labels = [
        read_label(label_path / (name + ".label")) for name in ("Aud-lh", "Vis-rh")
    ]

    # fixed orientation: one column per retained source
    fwd = read_forward_solution(fname_meeg)
    fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True)
    fwd = pick_types_forward(fwd, meg=True)
    fwd_out = restrict_forward_to_label(fwd, labels)
    _assert_restricted_to(fwd_out, _expected_label_vertices(fwd, labels), n_orient=1)

    # solution as read from disk (no conversion): three columns per source
    fwd = read_forward_solution(fname_meeg)
    fwd = pick_types_forward(fwd, meg=True)
    fwd_out = restrict_forward_to_label(fwd, labels)
    _assert_restricted_to(fwd_out, _expected_label_vertices(fwd, labels), n_orient=3)

    # Test saving the restricted forward object. This only works if all fields
    # are properly accounted for.
    fname_copy = tmp_path / "copy-fwd.fif"
    write_forward_solution(fname_copy, fwd_out, overwrite=True)
    fwd_out_read = read_forward_solution(fname_copy)
    assert_forward_allclose(fwd_out, fwd_out_read)
@pytest.mark.parametrize("use_cps", [True, False])
@testing.requires_testing_data
def test_restrict_forward_to_label_cps(tmp_path, use_cps):
    """Regression test for gh-11689.

    The gain columns of a vertex must be the same in the full and the
    label-restricted solution, and must stay the same after converting the
    restricted solution again with identical settings.
    """
    convert_kwargs = dict(
        surf_ori=True, force_fixed=False, copy=False, use_cps=use_cps
    )

    def _vertex_gains(forward, position):
        """Copy the three original and converted gain columns of one source."""
        cols = slice(position * 3, position * 3 + 3)
        return forward["_orig_sol"][:, cols].copy(), forward["sol"]["data"][:, cols].copy()

    label_lh = read_label(label_path / "Aud-lh.label")
    fwd = read_forward_solution(fname_meeg)
    # copy=False: the conversion modifies ``fwd`` in place
    convert_forward_solution(fwd, **convert_kwargs)
    fwd = pick_types_forward(fwd, meg="mag")
    fwd_out = restrict_forward_to_label(fwd, label_lh)
    vert = fwd_out["src"][0]["vertno"][0]

    # reference gains from the full solution
    assert fwd["surf_ori"]
    assert not is_fixed_orient(fwd)
    pos_full = list(fwd["src"][0]["vertno"]).index(vert)
    assert pos_full == 126
    orig_full, conv_full = _vertex_gains(fwd, pos_full)

    # gains of the same vertex in the restricted solution
    assert fwd_out["surf_ori"]
    assert not is_fixed_orient(fwd_out)
    pos_out = list(fwd_out["src"][0]["vertno"]).index(vert)
    assert pos_out == 0
    orig_out, conv_out = _vertex_gains(fwd_out, pos_out)
    assert_allclose(orig_out, orig_full)
    assert_allclose(conv_out, conv_full)

    # should be a no-op
    convert_forward_solution(fwd_out, **convert_kwargs)
    assert fwd_out["surf_ori"]
    assert not is_fixed_orient(fwd_out)
    assert list(fwd_out["src"][0]["vertno"]).index(vert) == 0
    orig_again, conv_again = _vertex_gains(fwd_out, pos_out)
    assert_allclose(orig_again, orig_full)
    assert_allclose(conv_again, conv_full)
@testing.requires_testing_data
@requires_mne
def test_average_forward_solution(tmp_path):
    """Check input validation and results of averaging forward solutions."""
    fwd = read_forward_solution(fname_meeg)

    # (expected exception, positional arguments) for each invalid call
    invalid_calls = (
        (TypeError, (1,)),  # input not a list
        (ValueError, ([],)),  # list is too short
        (ValueError, ([fwd, fwd], [-1, 0])),  # negative weights
        (ValueError, ([fwd, fwd], [0, 0])),  # all zero weights
        (ValueError, ([fwd, fwd], [0, 0, 0])),  # weights not same length
        (TypeError, ([1, fwd],)),  # list does not only have all dict()
    )
    for expected_exc, call_args in invalid_calls:
        with pytest.raises(expected_exc):
            average_forward_solutions(*call_args)

    # averaging a single solution must reproduce it
    halved = average_forward_solutions([fwd])
    assert isinstance(halved, Forward)
    assert_array_equal(fwd["sol"]["data"], halved["sol"]["data"])

    # scale the copy, save it, and let the MNE command-line tool average it
    # with the original file
    halved["sol"]["data"] *= 0.5
    fname_copy = str(tmp_path / "copy-fwd.fif")
    write_forward_solution(fname_copy, halved, overwrite=True)
    run_subprocess(
        (
            "mne_average_forward_solutions",
            "--fwd",
            fname_meeg,
            "--fwd",
            fname_copy,
            "--out",
            fname_copy,
        )
    )

    # now average in Python: mean of 1.0x and 0.5x is 0.75x
    averaged = average_forward_solutions([fwd, halved])
    assert_array_equal(0.75 * fwd["sol"]["data"], averaged["sol"]["data"])
    # NOTE: the file written by the command-line tool is not read back or
    # compared here; only the successful run of the command is exercised.

    # with gradient
    fwd = read_forward_solution(fname_meeg_grad)
    averaged = average_forward_solutions([fwd, fwd])
    assert_forward_allclose(fwd, averaged)
@testing.requires_testing_data
def test_priors():
    """Check depth and orientation prior computations and their validation."""
    free = read_forward_solution(fname_meeg)
    assert not is_fixed_orient(free)
    n_src = free["nsource"]
    info = read_info(fname_evoked)

    # Depth prior: one value per column of the free-orientation solution
    prior = compute_depth_prior(free, info, exp=0.8)
    assert prior.shape == (3 * n_src,)
    prior = compute_depth_prior(free, info, exp=0.0)
    assert_array_equal(prior, 1.0)
    with pytest.raises(ValueError, match='must be "whiten"'):
        compute_depth_prior(free, info, limit_depth_chs="foo")
    with pytest.raises(ValueError, match="noise_cov must be a Covariance"):
        compute_depth_prior(free, info, limit_depth_chs="whiten")
    fixed = convert_forward_solution(free, force_fixed=True)
    prior = compute_depth_prior(fixed, info=info)
    assert prior.shape == (n_src,)

    # Orientation prior
    prior = compute_orient_prior(free, 1.0)
    assert_array_equal(prior, 1.0)
    prior = compute_orient_prior(fixed, 0.0)
    assert_array_equal(prior, 1.0)
    with pytest.raises(ValueError, match="oriented in surface coordinates"):
        compute_orient_prior(free, 0.5)
    surf_oriented = convert_forward_solution(free, surf_ori=True)
    prior = compute_orient_prior(surf_oriented, 0.5)
    assert np.isin(prior, (0.5, 1.0)).all()
    with pytest.raises(ValueError, match="between 0 and 1"):
        compute_orient_prior(surf_oriented, -0.5)
    with pytest.raises(ValueError, match="with fixed orientation"):
        compute_orient_prior(fixed, 0.5)
@testing.requires_testing_data
def test_equalize_channels():
    """Check that equalize_channels aligns channels of Forward instances."""
    first = read_forward_solution(fname_meeg)
    # copy=False: ``first`` itself is reduced to the three EEG channels
    pick_channels_forward(first, include=["EEG 001", "EEG 002", "EEG 003"], copy=False)
    # second instance has a subset of the channels, in a different order
    second = pick_channels_forward(first, include=["EEG 002", "EEG 001"], ordered=True)
    first, second = equalize_channels([first, second])
    common = ["EEG 001", "EEG 002"]
    for equalized in (first, second):
        assert equalized.ch_names == common