# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from pathlib import Path
import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal
from mne import pick_types
from mne._fiff.proj import activate_proj, make_projector
from mne.datasets import testing
from mne.io import read_raw_ctf, read_raw_fif
from mne.preprocessing.ssp import compute_proj_ecg, compute_proj_eog
from mne.utils import _record_warnings
# Directory holding the small FIF test recording used by the fixture below.
data_path = Path(__file__).parents[2].joinpath("io", "tests", "data")
raw_fname = data_path.joinpath("test_raw.fif")
# Seconds of data handed to the projector functions (passed as ``tmax``).
dur_use = 5.0
# Reference EOG event times (s); the EOG test compares the number of detected
# events with how many of these fall before ``dur_use``.
eog_times = np.array([0.5, 2.3, 3.6, 14.5])
# CTF recording from the optional testing dataset (never downloaded here).
ctf_fname = testing.data_path(download=False) / "CTF" / "testdata_ctf.ds"
@pytest.fixture()
def short_raw():
    """Return a small, preloaded raw object for fast projector tests.

    The recording is cropped to 0-7 s and restricted to MEG/EEG/EOG channels.
    Only every 10th of the first 306 channels (presumably the MEG sensors) is
    kept, together with all remaining channels. Projectors are normalized.
    """
    raw = read_raw_fif(raw_fname)
    raw.crop(0, 7)
    raw.pick(["meg", "eeg", "eog"])
    sparse_front = raw.ch_names[:306:10]
    remainder = raw.ch_names[306:]
    raw.pick(sparse_front + remainder)
    raw.load_data()
    raw.info.normalize_proj()
    return raw
@pytest.mark.parametrize("average", (True, False))
def test_compute_proj_ecg(short_raw, average):
    """Test computation of ECG SSP projectors.

    First runs ``compute_proj_ecg`` without filtering or rejection. It checks
    the projector count, the number of detected heartbeats, and that each
    channel type's first PCA component explains a minimum fraction of the
    variance. It then exercises the warning paths for a too-short signal and
    for a run in which no good epochs are found.
    """
    raw = short_raw
    # For speed, let's not filter here (must also not reject then)
    with pytest.warns(RuntimeWarning, match="Attenuation"):
        projs, events = compute_proj_ecg(
            raw,
            n_mag=2,
            n_grad=2,
            n_eeg=2,
            ch_name="MEG 1531",
            bads=["MEG 2443"],
            average=average,
            avg_ref=True,
            no_proj=True,
            l_freq=None,
            h_freq=None,
            reject=None,
            tmax=dur_use,
            qrs_threshold=0.5,
            filter_length=1000,
        )
    # 2 mag + 2 grad + 2 EEG projectors plus one more (presumably the
    # average-reference projector requested via avg_ref=True)
    assert len(projs) == 7
    # heart rate at least 0.5 Hz, but less than 3 Hz
    assert 0.5 * dur_use < events.shape[0] < 3 * dur_use
    # check that the first principal component of each channel type has a
    # certain minimum explained variance
    ssp_ecg = [
        proj
        for proj in projs
        if proj["desc"].startswith("ECG") and "PCA-01" in proj["desc"]
    ]
    # Without this guard the loop below would pass silently if the projector
    # description format ever stopped matching the filter above.
    assert ssp_ecg, "no first-component ECG projectors found"
    thresh_eeg, thresh_axial, thresh_planar = 0.9, 0.3, 0.1
    for proj in ssp_ecg:
        if "planar" in proj["desc"]:
            assert proj["explained_var"] > thresh_planar
        elif "axial" in proj["desc"]:
            assert proj["explained_var"] > thresh_axial
        elif "eeg" in proj["desc"]:
            assert proj["explained_var"] > thresh_eeg
    # XXX: better tests
    # without setting a bad channel, this should throw a warning
    # (first with a call that makes sure we copy the mutable default "reject")
    with pytest.warns(RuntimeWarning, match="longer than the signal"):
        compute_proj_ecg(raw.copy().pick("mag"), l_freq=None, h_freq=None)
    with _record_warnings(), pytest.warns(RuntimeWarning, match="No good epochs found"):
        projs, events, drop_log = compute_proj_ecg(
            raw,
            n_mag=2,
            n_grad=2,
            n_eeg=2,
            ch_name="MEG 1531",
            bads=[],
            average=average,
            avg_ref=True,
            no_proj=True,
            l_freq=None,
            h_freq=None,
            tmax=dur_use,
            return_drop_log=True,
            # XXX can be removed once
            # XXX https://github.com/mne-tools/mne-python/issues/9273
            # XXX has been resolved:
            qrs_threshold=1e-15,
        )
    assert projs == []
    assert len(events) == len(drop_log)
@pytest.mark.parametrize("average", [True, False])
def test_compute_proj_eog(average, short_raw):
    """Test computation of EOG SSP projectors.

    First runs ``compute_proj_eog`` without filtering or rejection. It checks
    the projector count, the number of detected EOG events, and the explained
    variance of each channel type's first PCA component. It then checks the
    too-short-signal warning path, and that a flat EOG channel still runs
    through with the expected filter warning.
    """
    raw = short_raw
    n_projs_init = len(raw.info["projs"])
    with pytest.warns(RuntimeWarning, match="Attenuation"):
        projs, events = compute_proj_eog(
            raw,
            n_mag=2,
            n_grad=2,
            n_eeg=2,
            bads=["MEG 2443"],
            average=average,
            avg_ref=True,
            no_proj=False,
            l_freq=None,
            h_freq=None,
            reject=None,
            tmax=dur_use,
            filter_length=1000,
        )
    # no_proj=False: the projectors already in raw.info are returned as well
    assert len(projs) == (7 + n_projs_init)
    # detected events should match the reference times inside the analyzed
    # window, give or take one
    assert np.abs(events.shape[0] - np.sum(eog_times < dur_use)) <= 1
    # check that the first principal component of each channel type has a
    # certain minimum explained variance
    ssp_eog = [
        proj
        for proj in projs
        if proj["desc"].startswith("EOG") and "PCA-01" in proj["desc"]
    ]
    # Without this guard the loop below would pass silently if the projector
    # description format ever stopped matching the filter above.
    assert ssp_eog, "no first-component EOG projectors found"
    thresh_eeg, thresh_axial, thresh_planar = 0.9, 0.3, 0.1
    for proj in ssp_eog:
        if "planar" in proj["desc"]:
            assert proj["explained_var"] > thresh_planar
        elif "axial" in proj["desc"]:
            assert proj["explained_var"] > thresh_axial
        elif "eeg" in proj["desc"]:
            assert proj["explained_var"] > thresh_eeg
    # XXX: better tests
    with _record_warnings(), pytest.warns(RuntimeWarning, match="longer"):
        projs, events = compute_proj_eog(
            raw,
            n_mag=2,
            n_grad=2,
            n_eeg=2,
            average=average,
            bads=[],
            avg_ref=True,
            no_proj=False,
            l_freq=None,
            h_freq=None,
            tmax=dur_use,
        )
    assert projs == []
    # Make the EOG channel flat (constant). The call below must still complete
    # and emit the filter-length warning; its outputs are not checked.
    raw._data[raw.ch_names.index("EOG 061"), :] = 1.0
    with (
        _record_warnings(),
        pytest.warns(RuntimeWarning, match="filter.*longer than the signal"),
    ):
        projs, events = compute_proj_eog(raw=raw, tmax=dur_use, ch_name="EOG 061")
@pytest.mark.slowtest  # can be slow on OSX
def test_compute_proj_parallel(short_raw):
    """Test computation of ExG projectors using parallelization.

    The same EOG projectors are computed once with ``n_jobs=None`` and once
    with ``n_jobs=2``. The resulting projection matrices must agree.
    """
    short_raw = short_raw.copy().pick(("eeg", "eog")).resample(100)
    # Arguments shared by both runs, so the runs differ only in ``n_jobs``
    # (and in which copy of the data they operate on).
    common_kwargs = dict(
        n_eeg=2,
        average=False,
        avg_ref=True,
        no_proj=False,
        l_freq=None,
        h_freq=None,
        reject=None,
        tmax=dur_use,
        filter_length=100,
    )
    raw = short_raw.copy()
    with pytest.warns(RuntimeWarning, match="Attenuation"):
        projs, _ = compute_proj_eog(
            raw, bads=raw.ch_names[1:2], n_jobs=None, **common_kwargs
        )
    raw_2 = short_raw.copy()
    with _record_warnings(), pytest.warns(RuntimeWarning, match="Attenuation"):
        projs_2, _ = compute_proj_eog(
            raw_2, bads=raw_2.ch_names[1:2], n_jobs=2, **common_kwargs
        )
    projs = activate_proj(projs)
    projs_2 = activate_proj(projs_2)
    # NOTE(review): "MEG 2443" was removed by the eeg/eog pick above, so this
    # bads entry presumably has no effect here -- confirm and drop if so.
    projs, _, _ = make_projector(projs, raw_2.info["ch_names"], bads=["MEG 2443"])
    projs_2, _, _ = make_projector(projs_2, raw_2.info["ch_names"], bads=["MEG 2443"])
    assert_array_almost_equal(projs, projs_2, 10)
def _check_projs_for_expected_channels(projs, n_mags, n_grads, n_eegs):
assert projs is not None
for p in projs:
if "planar" in p["desc"]:
assert len(p["data"]["col_names"]) == n_grads
elif "axial" in p["desc"]:
assert len(p["data"]["col_names"]) == n_mags
elif "eeg" in p["desc"]:
assert len(p["data"]["col_names"]) == n_eegs
@pytest.mark.slowtest  # can be slow on OSX
@testing.requires_testing_data
def test_compute_proj_ctf():
    """Check that EOG and ECG projector computation completes on CTF data.

    Projectors are computed on a sparse channel subset, first with gradient
    compensation grade 0 (EOG) and then grade 1 (ECG). Each run's projector
    count and per-type channel counts are verified.
    """
    raw = read_raw_ctf(ctf_fname, preload=True)
    # Keep a sparse subset of each sensor type plus the reference channels,
    # and remember how many of each type the projectors should span.
    mag_picks = pick_types(raw.info, meg="mag", ref_meg=False, exclude="bads")[::10]
    grad_picks = pick_types(raw.info, meg="grad", ref_meg=False, exclude="bads")[::10]
    eeg_picks = pick_types(
        raw.info, meg=False, eeg=True, ref_meg=False, exclude="bads"
    )[2::3]
    ref_picks = pick_types(raw.info, meg=False, ref_meg=True)
    n_mags, n_grads, n_eegs = len(mag_picks), len(grad_picks), len(eeg_picks)
    raw.pick(np.sort(np.concatenate([mag_picks, grad_picks, eeg_picks, ref_picks])))
    del mag_picks, grad_picks, eeg_picks, ref_picks
    # Settings shared by the EOG and ECG runs: no filtering, no rejection.
    shared_kwargs = dict(
        average=True,
        ch_name="EEG059",
        avg_ref=True,
        no_proj=False,
        l_freq=None,
        h_freq=None,
        reject=None,
        tmax=dur_use,
        filter_length=1000,
    )
    # Test with and without gradient compensation
    raw.apply_gradient_compensation(0)
    n_projs_init = len(raw.info["projs"])
    with pytest.warns(RuntimeWarning, match="Attenuation"):
        projs, _ = compute_proj_eog(raw, n_mag=2, n_grad=2, n_eeg=2, **shared_kwargs)
    _check_projs_for_expected_channels(projs, n_mags, n_grads, n_eegs)
    assert len(projs) == (5 + n_projs_init)
    raw.apply_gradient_compensation(1)
    with pytest.warns(RuntimeWarning, match="Attenuation"):
        projs, _ = compute_proj_ecg(raw, n_mag=1, n_grad=1, n_eeg=2, **shared_kwargs)
    _check_projs_for_expected_channels(projs, n_mags, n_grads, n_eegs)
    assert len(projs) == (4 + n_projs_init)