"""Test check utilities."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import os
import sys
from pathlib import Path
import numpy as np
import pytest
import mne
from mne import pick_channels_cov, read_vectorview_selection
from mne._fiff.pick import _picks_to_idx
from mne.datasets import testing
from mne.utils import (
Bunch,
_check_ch_locs,
_check_fname,
_check_info_inv,
_check_option,
_check_range,
_check_sphere,
_check_subject,
_on_missing,
_path_like,
_record_warnings,
_safe_input,
_soft_import,
_suggest,
_validate_type,
catch_logging,
check_fname,
check_random_state,
check_version,
)
# Paths into the MNE testing dataset. ``download=False`` means the data is
# not fetched here; tests that need it are decorated with
# ``@testing.requires_testing_data``.
data_path = testing.data_path(download=False)
base_dir = data_path / "MEG" / "sample"
# Build every sample-recording path from ``base_dir`` for consistency.
fname_raw = base_dir / "sample_audvis_trunc_raw.fif"
fname_event = base_dir / "sample_audvis_trunc_raw-eve.fif"
fname_fwd = base_dir / "sample_audvis_trunc-meg-vol-7-fwd.fif"
fname_mgz = data_path / "subjects" / "sample" / "mri" / "aseg.mgz"
# Epoch rejection thresholds used by ``_get_data`` (per channel type).
reject = dict(grad=4000e-13, mag=4e-12)
@testing.requires_testing_data
def test_check(tmp_path):
    """Test checking functions.

    Exercises input validation of ``check_random_state``, the existence and
    read-permission checks of ``_check_fname``, suffix validation in
    ``check_fname`` and argument validation in ``_check_subject``.
    """
    pytest.raises(ValueError, check_random_state, "foo")
    pytest.raises(TypeError, _check_fname, 1)
    _check_fname(Path("./foo"))

    fname = tmp_path / "foo"
    with open(fname, "wb"):
        pass
    assert fname.is_file()
    _check_fname(fname, overwrite="read", must_exist=True)

    orig_perms = os.stat(fname).st_mode
    os.chmod(fname, 0)
    # Restore the permissions in ``finally`` so that a failing assertion does
    # not leave a file with mode 0 behind.
    try:
        # The read-permission error is only asserted on non-Windows platforms.
        if not sys.platform.startswith("win"):
            with pytest.raises(PermissionError, match="read permissions"):
                _check_fname(fname, overwrite="read", must_exist=True)
    finally:
        os.chmod(fname, orig_perms)
    os.remove(fname)
    assert not fname.is_file()

    pytest.raises(OSError, check_fname, "foo", "tets-dip.x", (), (".fif",))
    pytest.raises(ValueError, _check_subject, None, None)
    pytest.raises(TypeError, _check_subject, None, 1)
    pytest.raises(TypeError, _check_subject, 1, None)

    # smoke tests for permitted types
    check_random_state(None).choice(1)
    check_random_state(0).choice(1)
    check_random_state(np.random.RandomState(0)).choice(1)
    check_random_state(np.random.default_rng(0)).choice(1)
@testing.requires_testing_data
@pytest.mark.parametrize(
    "suffix",
    ("_meg.fif", "_eeg.fif", "_ieeg.fif", "_meg.fif.gz", "_eeg.fif.gz", "_ieeg.fif.gz"),
)
def test_check_fname_suffixes(suffix, tmp_path):
    """Test that raw files saved with each accepted suffix can be read back.

    A 0.1 s crop of the test recording is saved under a name ending in
    ``suffix`` and then re-read with ``read_raw_fif``.
    """
    target = tmp_path / fname_raw.name.replace("_raw.fif", suffix)
    cropped = mne.io.read_raw_fif(fname_raw).crop(0, 0.1)
    cropped.save(target)
    mne.io.read_raw_fif(target)
def _get_data():
    """Load the data shared by the ``_check_info_inv`` tests.

    Returns
    -------
    epochs : Epochs
        Epochs (event id 1, -0.1 to 0.15 s) built from every other
        left-temporal MEG channel of the test recording.
    data_cov : Covariance
        Covariance over 0.01 to 0.15 s.
    noise_cov : Covariance
        Covariance over the pre-stimulus interval (up to 0 s).
    forward : Forward
        The volume forward solution from ``fname_fwd``.
    """
    forward = mne.read_forward_solution(fname_fwd)

    raw = mne.io.read_raw_fif(fname_raw, preload=True)
    events = mne.read_events(fname_event)
    event_id, tmin, tmax = 1, -0.1, 0.15

    # Keep only every second left-temporal MEG channel to speed things up.
    selection = read_vectorview_selection("Left-temporal")
    meg_idx = mne.pick_types(raw.info, meg=True, selection=selection)[::2]
    raw.pick([raw.ch_names[idx] for idx in meg_idx])
    raw.info.normalize_proj()  # avoid projection warnings

    epochs = mne.Epochs(
        raw,
        events,
        event_id,
        tmin,
        tmax,
        proj=True,
        baseline=(None, 0.0),
        preload=True,
        reject=reject,
    )
    noise_cov = mne.compute_covariance(epochs, tmin=None, tmax=0.0)
    data_cov = mne.compute_covariance(epochs, tmin=0.01, tmax=0.15)
    return epochs, data_cov, noise_cov, forward
@testing.requires_testing_data
def test_check_info_inv():
    """Test checks for common channels across fwd model and cov matrices.

    ``_check_info_inv`` should drop channels marked bad in the info or in
    either covariance. It should also drop reference channels and return only
    the channels common to all inputs, logging how many were excluded.
    """
    epochs, data_cov, noise_cov, forward = _get_data()

    # make sure same channel lists exist in data to make testing life easier
    assert epochs.info["ch_names"] == data_cov.ch_names
    assert epochs.info["ch_names"] == noise_cov.ch_names

    # check whether bad channels get excluded from the channel selection
    # info
    info_bads = epochs.info.copy()
    info_bads["bads"] = info_bads["ch_names"][1:3]  # include two bad channels
    picks = _check_info_inv(info_bads, forward, noise_cov=noise_cov)
    # Check each index on its own: ``[1, 2] not in picks`` compares the list
    # against every integer pick and is therefore always True.
    assert 1 not in picks
    assert 2 not in picks

    # covariance matrix
    data_cov_bads = data_cov.copy()
    data_cov_bads["bads"] = [data_cov_bads.ch_names[0]]
    picks = _check_info_inv(epochs.info, forward, data_cov=data_cov_bads)
    assert 0 not in picks

    # noise covariance matrix
    noise_cov_bads = noise_cov.copy()
    noise_cov_bads["bads"] = [noise_cov_bads.ch_names[1]]
    picks = _check_info_inv(epochs.info, forward, noise_cov=noise_cov_bads)
    assert 1 not in picks

    # test whether reference channels get deleted
    info_ref = epochs.info.copy()
    info_ref["chs"][0]["kind"] = 301  # pretend to have a ref channel
    picks = _check_info_inv(info_ref, forward, noise_cov=noise_cov)
    assert 0 not in picks

    # pick channels in all inputs and make sure common set is returned
    epochs.pick(epochs.ch_names[:10])
    data_cov = pick_channels_cov(data_cov, include=data_cov.ch_names[5:20])
    noise_cov = pick_channels_cov(noise_cov, include=noise_cov.ch_names[7:12])
    with catch_logging() as log:
        picks = _check_info_inv(
            epochs.info, forward, noise_cov=noise_cov, data_cov=data_cov, verbose=True
        )
    assert list(range(7, 10)) == picks

    # make sure to inform the user that 7 channels were dropped
    # (there are 10 channels in epochs but only 3 were picked)
    log = log.getvalue()
    assert "Excluding 7 channel(s) missing" in log
def test_check_option():
    """Test validation of a parameter value against a list of options.

    Accepted values must pass. Rejected values must raise ``ValueError``
    whose message lists the options, with a singular wording when only one
    option exists.
    """
    allowed_values = ["valid", "good", "ok"]

    # Every allowed value is accepted.
    for candidate in allowed_values:
        assert _check_option("option", candidate, allowed_values)
    assert _check_option("option", "valid", ["valid"])

    # Several options: the message enumerates all of them.
    many_msg = (
        "Invalid value for the 'option' parameter. Allowed values are "
        "'valid', 'good', and 'ok', but got 'bad' instead."
    )
    with pytest.raises(ValueError, match=many_msg):
        assert _check_option("option", "bad", allowed_values)

    # A single option gets dedicated wording.
    single_msg = (
        "Invalid value for the 'option' parameter. The only allowed value "
        "is 'valid', but got 'bad' instead."
    )
    with pytest.raises(ValueError, match=single_msg):
        assert _check_option("option", "bad", ["valid"])
def test_path_like():
    """Test _path_like() on str, pathlib and non-path inputs."""
    # Both string and pathlib representations count as path-like.
    for candidate in (str(base_dir), Path(base_dir)):
        assert _path_like(candidate) is True
    # An arbitrary mapping does not.
    assert _path_like(dict(foo="bar")) is False
def test_validate_type():
    """Test _validate_type for the "int-like" and "array-like" categories."""
    _validate_type(1, "int-like")
    # ``False`` must be rejected as "int-like" even though bool subclasses int.
    with pytest.raises(TypeError, match="int-like"):
        _validate_type(False, "int-like")

    for container in ([1, 2, 3], (1, 2, 3), {1, 2, 3}):
        _validate_type(container, "array-like")
    with pytest.raises(TypeError, match="array-like"):
        _validate_type("123", "array-like")  # a string is not array-like
def test_check_range():
    """Test _check_range accepts in-range values and rejects others."""
    _check_range(10, 1, 100, "value")  # well inside the range

    # Below the lower bound.
    with pytest.raises(ValueError, match="must be between"):
        _check_range(0, 1, 10, "value")

    # Equal to the lower bound with both trailing flags False
    # (presumably making the bounds exclusive; confirm against _check_range).
    with pytest.raises(ValueError, match="must be between"):
        _check_range(1, 1, 10, "value", False, False)
@testing.requires_testing_data
def test_suggest():
    """Test spelling suggestions built from the aseg volume label names."""
    pytest.importorskip("nibabel")
    label_names = mne.get_volume_labels_from_aseg(fname_mgz)

    # An empty query yields no suggestion at all.
    assert _suggest("", label_names) == ""

    # A close match yields a single suggestion.
    assert _suggest("Left-cerebellum", label_names) == (
        " Did you mean 'Left-Cerebellum-Cortex'?"
    )

    # An ambiguous query yields a list of candidates.
    expected = (
        " Did you mean one of ['Left-Cerebellum-Cortex', "
        "'Right-Cerebellum-Cortex', 'Left-Cerebral-Cortex']?"
    )
    assert _suggest("Cerebellum-Cortex", label_names) == expected
def test_on_missing():
    """Test each ``on_missing`` mode of _on_missing, plus an invalid mode."""
    text = "test"
    # "raise" -> ValueError carrying the message.
    with pytest.raises(ValueError, match=text):
        _on_missing("raise", text)
    # "warn" -> RuntimeWarning carrying the message.
    with pytest.warns(RuntimeWarning, match=text):
        _on_missing("warn", text)
    # "ignore" -> must not raise.
    _on_missing("ignore", text)
    # Anything else is itself an invalid parameter value.
    invalid_mode = "Invalid value for the 'on_missing' parameter"
    with pytest.raises(ValueError, match=invalid_mode):
        _on_missing("foo", text)
def _matlab_input(msg):
raise EOFError()
def test_safe_input(monkeypatch):
    """Test _safe_input when ``input`` is unusable (raises EOFError)."""
    # Make the ``input`` name looked up in mne.utils.check raise EOFError.
    monkeypatch.setattr(mne.utils.check, "input", _matlab_input)

    # With only ``alt`` given, the failure surfaces as a RuntimeError.
    with pytest.raises(RuntimeError, match="Could not use input"):
        _safe_input("whatever", alt="nothing")

    # With ``use`` given, that value is returned instead.
    answer = _safe_input("whatever", use="nothing")
    assert answer == "nothing"
@testing.requires_testing_data
def test_check_ch_locs():
    """Test _check_ch_locs before and after wiping the EEG locations."""
    info = mne.io.read_info(fname_raw)

    # With intact locations, every kind of query succeeds.
    assert _check_ch_locs(info=info)
    for picks in ([0], [0, 1], None):
        assert _check_ch_locs(info=info, picks=picks)
    for ch_type in ("meg", "mag", "grad", "eeg"):
        assert _check_ch_locs(info=info, ch_type=ch_type)

    # Set the first three ``loc`` entries of every EEG channel to NaN.
    eeg_idx = _picks_to_idx(info=info, picks="eeg")
    for ch in (info["chs"][idx] for idx in eeg_idx):
        ch["loc"][:3] = np.nan

    # Queries restricted to EEG now report missing locations ...
    assert _check_ch_locs(info=info, picks=eeg_idx) is False
    assert _check_ch_locs(info=info, ch_type="eeg") is False
    # ... while "all" channels and other channel types still pass.
    assert _check_ch_locs(info=info)
    assert _check_ch_locs(info=info, ch_type="mag")
# Check a variety of version schemes in use as of 2022-03-01.
# This does not have to be fully general, but ideally all of these
# should work.
@pytest.mark.parametrize(
    "version, want, have_unstripped",
    [
        # test some dev cases
        ("1.23.0.dev0+782.g1168868df6", "1.23", False),  # NumPy
        ("1.9.0.dev0+1485.b06254e", "1.9", False),  # SciPy
        ("3.6.0.dev1651+g30d6161406", "3.6", False),  # matplotlib
        ("1.1.dev0", "1.1", False),  # sklearn
        ("0.56.0dev0+39.gef1ba4c10", "0.56", False),  # numba
        ("9.1.0.rc1", "9.1", False),  # VTK
        ("0.3dev0", "0.3", False),  # mne-connectivity
        ("0.2.2.dev0", "0.2.2", False),  # mne-qt-browser
        ("3.2.2+150.g1e93bd5d", "3.2.2", True),  # nibabel
        # test some stable cases
        ("1.2.3", "1.2.3", True),
        ("1.2", "1.2", True),
        ("1", "1", True),
    ],
)
def test_strip_dev(version, want, have_unstripped, monkeypatch):
    """Test that check_version strips dev/rc/local parts before comparing."""

    def looks_stable(ver):
        # "Stable" means every dot-separated component parses as an int.
        try:
            list(map(int, ver.split(".")))
        except ValueError:
            return False
        return True

    # Every import made by check_version reports ``version``.
    monkeypatch.setattr(
        mne.utils.check, "import_module", lambda _name: Bunch(__version__=version)
    )

    # Without stripping, the raw version string is returned untouched.
    got_have_unstripped, same_version = check_version(
        version, want, strip=False, return_version=True
    )
    assert same_version == version
    assert got_have_unstripped is have_unstripped

    # strip=True is the default
    have, simpler_version = check_version("foo", want, return_version=True)
    assert have, (simpler_version, version)

    if not looks_stable(version):
        assert simpler_version != version
        for tag in ("dev", "rc"):
            assert tag not in simpler_version
        assert not simpler_version.endswith(".")
        assert looks_stable(simpler_version)
    else:
        for tag in ("dev", "rc"):
            assert tag not in version
        assert simpler_version == version
@testing.requires_testing_data
def test_check_sphere_verbose():
    """Test that the log level controls the _check_sphere fit warning."""
    info = mne.io.read_info(fname_raw)
    # Truncate digitization to its first 20 points; with this info the
    # "auto" sphere fit is expected to warn below.
    with info._unlock():
        info["dig"] = info["dig"][:20]

    with _record_warnings(), pytest.warns(RuntimeWarning, match="may be inaccurate"):
        _check_sphere("auto", info)

    # Same call at "error" log level (the warning is presumably suppressed).
    with mne.use_log_level("error"):
        _check_sphere("auto", info)
def test_soft_import():
    """Test that _soft_import rejects an installed-but-too-old module."""
    too_old = r".* the module mne>=999 \(found version.*"
    with pytest.raises(RuntimeError, match=too_old):
        _soft_import("mne", "testing", min_version="999")