# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from copy import deepcopy
from datetime import date
from io import StringIO
from pathlib import Path
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal
from mne import pick_types, read_cov, read_evokeds
from mne._fiff.pick import _picks_by_type
from mne.epochs import make_fixed_length_epochs
from mne.fixes import _eye_array
from mne.io import read_raw_fif
from mne.time_frequency import tfr_morlet
from mne.utils import (
_PCA,
_apply_scaling_array,
_apply_scaling_cov,
_array_equal_nan,
_custom_lru_cache,
_date_to_julian,
_freq_mask,
_get_inst_data,
_julian_to_date,
_reg_pinv,
_replace_md5,
_ReuseCycle,
_time_mask,
_undo_scaling_array,
_undo_scaling_cov,
compute_corr,
create_slices,
grand_average,
hashfunc,
numerics,
object_diff,
object_hash,
object_size,
random_permutation,
sum_squared,
)
from mne.utils.numerics import _LRU_CACHE_MAXSIZES, _LRU_CACHES

base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
fname_raw = base_dir / "test_raw.fif"
ave_fname = base_dir / "test-ave.fif"
cov_fname = base_dir / "test-cov.fif"


def test_get_inst_data():
"""Test _get_inst_data."""
raw = read_raw_fif(fname_raw)
raw.crop(tmax=1.0)
assert_array_equal(_get_inst_data(raw), raw._data)
raw.pick(raw.ch_names[:2])
epochs = make_fixed_length_epochs(raw, 0.5)
assert_array_equal(_get_inst_data(epochs), epochs._data)
evoked = epochs.average()
assert_array_equal(_get_inst_data(evoked), evoked.data)
evoked.crop(tmax=0.1)
picks = list(range(2))
freqs = [50.0, 55.0]
n_cycles = 3
tfr = tfr_morlet(evoked, freqs, n_cycles, return_itc=False, picks=picks)
assert_array_equal(_get_inst_data(tfr), tfr.data)
pytest.raises(TypeError, _get_inst_data, "foo")


def test_hashfunc(tmp_path):
"""Test md5/sha1 hash calculations."""
fname1 = tmp_path / "foo"
fname2 = tmp_path / "bar"
with open(fname1, "wb") as fid:
fid.write(b"abcd")
with open(fname2, "wb") as fid:
fid.write(b"efgh")
for hash_type in ("md5", "sha1"):
hash1 = hashfunc(fname1, hash_type=hash_type)
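        # hashing with a block size of 1 must yield the same digest as the default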
hash1_ = hashfunc(fname1, 1, hash_type=hash_type)
hash2 = hashfunc(fname2, hash_type=hash_type)
hash2_ = hashfunc(fname2, 1024, hash_type=hash_type)
assert hash1 == hash1_
assert hash2 == hash2_
assert hash1 != hash2


def test_sum_squared():
"""Test optimized sum of squares."""
X = np.random.RandomState(0).randint(0, 50, (3, 3))
assert np.sum(X**2) == sum_squared(X)


def test_compute_corr():
    """Test compute_corr on Anscombe's quartet."""
x = np.array([10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5])
y = np.array(
[
[8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68],
[9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74],
[7.46, 6.77, 12.74, 7.11, 7.81, 8.84, 6.08, 5.39, 8.15, 6.42, 5.73],
[8, 8, 8, 8, 8, 8, 8, 19, 8, 8, 8],
[6.58, 5.76, 7.71, 8.84, 8.47, 7.04, 5.25, 12.50, 5.56, 7.91, 6.89],
]
)
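    # compare the vectorized correlations against np.corrcoef, row by row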
r = compute_corr(x, y.T)
r2 = np.array([np.corrcoef(x, y[i])[0, 1] for i in range(len(y))])
assert_allclose(r, r2)
pytest.raises(ValueError, compute_corr, [1, 2], [])


def test_create_slices():
    """Test creation of time slices."""
    # Test that an empty range yields an empty list
    assert create_slices(0, 0) == []
    # Test that create_slices returns the correct number of slices
    assert len(create_slices(0, 100)) == 100
    # Test with a non-zero start parameter
    assert len(create_slices(50, 100)) == 50
    # Test the number of slices with a window length of 2
    assert len(create_slices(0, 100, length=2)) == 50
    # Test the number of slices with a manual step
    assert len(create_slices(0, 100, step=10)) == 10
    # Test the number of overlapping slices (length > step)
    assert len(create_slices(0, 500, length=50, step=10)) == 46
    # Test that slice elements have the correct start, stop, and step
slices = create_slices(0, 10)
assert slices[0].start == 0
assert slices[0].step == 1
assert slices[0].stop == 1
assert slices[-1].stop == 10
# Same with larger window width
slices = create_slices(0, 9, length=3)
assert slices[0].start == 0
assert slices[0].step == 1
assert slices[0].stop == 3
assert slices[-1].stop == 9
    # Same with a manual step
slices = create_slices(0, 9, length=3, step=1)
assert len(slices) == 7
assert slices[0].step == 1
assert slices[0].stop == 3
assert slices[-1].start == 6
assert slices[-1].stop == 9


def test_time_mask():
"""Test safe time masking."""
N = 10
x = np.arange(N).astype(float)
assert _time_mask(x, 0, N - 1).sum() == N
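    # a tiny negative offset must still be included thanks to the sfreq tolerance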
assert _time_mask(x - 1e-10, 0, N - 1, sfreq=1000.0).sum() == N
assert _time_mask(x - 1e-10, None, N - 1, sfreq=1000.0).sum() == N
assert _time_mask(x - 1e-10, None, None, sfreq=1000.0).sum() == N
assert _time_mask(x - 1e-10, -np.inf, None, sfreq=1000.0).sum() == N
assert _time_mask(x - 1e-10, None, np.inf, sfreq=1000.0).sum() == N
# non-uniformly spaced inputs
x = np.array([4, 10])
assert _time_mask(x[:1], tmin=10, sfreq=1, raise_error=False).sum() == 0
assert _time_mask(x[:1], tmin=11, tmax=12, sfreq=1, raise_error=False).sum() == 0
assert _time_mask(x, tmin=10, sfreq=1).sum() == 1
assert _time_mask(x, tmin=6, sfreq=1).sum() == 1
assert _time_mask(x, tmin=5, sfreq=1).sum() == 1
assert _time_mask(x, tmin=4.5001, sfreq=1).sum() == 1
assert _time_mask(x, tmin=4.4999, sfreq=1).sum() == 2
assert _time_mask(x, tmin=4, sfreq=1).sum() == 2
# degenerate cases
with pytest.raises(ValueError, match="No samples remain"):
_time_mask(x[:1], tmin=11, tmax=12)
with pytest.raises(ValueError, match="must be less than or equal to tmax"):
_time_mask(x[:1], tmin=10, sfreq=1)


def test_freq_mask():
"""Test safe frequency masking."""
N = 10
x = np.arange(N).astype(float)
assert _freq_mask(x, 1000.0, fmin=0, fmax=N - 1).sum() == N
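    # as in test_time_mask, tiny offsets must be tolerated given sfreq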
assert _freq_mask(x - 1e-10, 1000.0, fmin=0, fmax=N - 1).sum() == N
assert _freq_mask(x - 1e-10, 1000.0, fmin=None, fmax=N - 1).sum() == N
assert _freq_mask(x - 1e-10, 1000.0, fmin=None, fmax=None).sum() == N
assert _freq_mask(x - 1e-10, 1000.0, fmin=-np.inf, fmax=None).sum() == N
assert _freq_mask(x - 1e-10, 1000.0, fmin=None, fmax=np.inf).sum() == N
# non-uniformly spaced inputs
x = np.array([4, 10])
assert _freq_mask(x[:1], 1, fmin=10, raise_error=False).sum() == 0
assert _freq_mask(x[:1], 1, fmin=11, fmax=12, raise_error=False).sum() == 0
assert _freq_mask(x, sfreq=1, fmin=10).sum() == 1
assert _freq_mask(x, sfreq=1, fmin=6).sum() == 1
assert _freq_mask(x, sfreq=1, fmin=5).sum() == 1
assert _freq_mask(x, sfreq=1, fmin=4.5001).sum() == 1
assert _freq_mask(x, sfreq=1, fmin=4.4999).sum() == 2
assert _freq_mask(x, sfreq=1, fmin=4).sum() == 2
# degenerate cases
with pytest.raises(ValueError, match="sfreq can not be None"):
_freq_mask(x[:1], sfreq=None, fmin=3, fmax=5)
with pytest.raises(ValueError, match="No frequencies remain"):
_freq_mask(x[:1], sfreq=1, fmin=11, fmax=12)
with pytest.raises(ValueError, match="must be less than or equal to fmax"):
_freq_mask(x[:1], sfreq=1, fmin=10)


def test_random_permutation():
"""Test random permutation function."""
n_samples = 10
random_state = 42
python_randperm = random_permutation(n_samples, random_state)
    # MATLAB output when we execute rng(42); randperm(10)
matlab_randperm = np.array([7, 6, 5, 1, 4, 9, 10, 3, 8, 2])
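    # MATLAB indices are 1-based, hence the -1 offset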
assert_array_equal(python_randperm, matlab_randperm - 1)


def test_cov_scaling():
"""Test rescaling covs."""
evoked = read_evokeds(ave_fname, condition=0, baseline=(None, 0), proj=True)
cov = read_cov(cov_fname)["data"]
cov2 = read_cov(cov_fname)["data"]
assert_array_equal(cov, cov2)
evoked.pick(
[evoked.ch_names[k] for k in pick_types(evoked.info, meg=True, eeg=True)]
)
picks_list = _picks_by_type(evoked.info)
scalings = dict(mag=1e15, grad=1e13, eeg=1e6)
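    # scaling must act in place and identically on both covariance copies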
_apply_scaling_cov(cov2, picks_list, scalings=scalings)
_apply_scaling_cov(cov, picks_list, scalings=scalings)
assert_array_equal(cov, cov2)
assert cov.max() > 1
_undo_scaling_cov(cov2, picks_list, scalings=scalings)
_undo_scaling_cov(cov, picks_list, scalings=scalings)
assert_array_equal(cov, cov2)
assert cov.max() < 1
data = evoked.data.copy()
_apply_scaling_array(data, picks_list, scalings=scalings)
_undo_scaling_array(data, picks_list, scalings=scalings)
assert_allclose(data, evoked.data, atol=1e-20)
@pytest.mark.parametrize("ndim", (2, 3))
def test_reg_pinv(ndim):
"""Test regularization and inversion of covariance matrix."""
# create rank-deficient array
a = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
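    # the first and last rows are identical, so the matrix has rank 2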
for _ in range(ndim - 2):
a = a[np.newaxis]
# Test if rank-deficient matrix without regularization throws
# specific warning
with pytest.warns(RuntimeWarning, match="deficient"):
_reg_pinv(a, reg=0.0)
# Test inversion with explicit rank
a_inv_np = np.linalg.pinv(a, hermitian=True)
a_inv_mne, loading_factor, rank = _reg_pinv(a, rank=2)
assert loading_factor == 0
assert rank == 2
assert_allclose(a_inv_np, a_inv_mne, atol=1e-14)
# Test inversion with automatic rank detection
a_inv_mne, _, estimated_rank = _reg_pinv(a, rank=None)
assert_allclose(a_inv_np, a_inv_mne, atol=1e-14)
assert estimated_rank == 2
# Test adding regularization
a_inv_mne, loading_factor, estimated_rank = _reg_pinv(a, reg=2)
# Since A has a diagonal of all ones, loading_factor should equal the
# regularization parameter
assert loading_factor == 2
# The estimated rank should be that of the non-regularized matrix
assert estimated_rank == 2
# Test result against the NumPy version
a_inv_np = np.linalg.pinv(a + loading_factor * np.eye(3), hermitian=True)
assert_allclose(a_inv_np, a_inv_mne, atol=1e-14)
# Test setting rcond
a_inv_np = np.linalg.pinv(a, rcond=0.5)
a_inv_mne, _, estimated_rank = _reg_pinv(a, rcond=0.5)
assert_allclose(a_inv_np, a_inv_mne, atol=1e-14)
assert estimated_rank == 1
# Test inverting an all zero cov
a_inv, loading_factor, estimated_rank = _reg_pinv(np.zeros((3, 3)), reg=2)
assert_array_equal(a_inv, 0)
assert loading_factor == 0
assert estimated_rank == 0


def test_object_size():
"""Test object size estimation."""
assert object_size(np.ones(10, np.float32)) < object_size(np.ones(10, np.float64))
for lower, upper, obj in (
(0, 60, ""),
(0, 30, 1),
(0, 30, 1.0),
(0, 70, "foo"),
(0, 150, np.ones(0)),
(0, 150, np.int32(1)),
(150, 500, np.ones(20)),
(30, 400, dict()),
(400, 1000, dict(a=np.ones(50))),
(200, 900, _eye_array(20, format="csc")),
(200, 900, _eye_array(20, format="csr")),
):
size = object_size(obj)
assert lower < size < upper, f"{lower} < {size} < {upper}:\n{obj}"
# views work properly
x = dict(a=1)
assert object_size(x) < 1000
x["a"] = np.ones(100000, float)
nb = x["a"].nbytes
sz = object_size(x)
assert nb < sz < nb * 1.01
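    # an aliased array must be counted only once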
x["b"] = x["a"]
sz = object_size(x)
assert nb < sz < nb * 1.01
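    # likewise for a read-only view of the same data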
x["b"] = x["a"].view()
x["b"].flags.writeable = False
assert x["a"].flags.writeable
sz = object_size(x)
assert nb < sz < nb * 1.01


def test_object_diff_with_nan():
"""Test object diff can handle NaNs."""
d0 = np.array([1, np.nan, 0])
d1 = np.array([1, np.nan, 0])
d2 = np.array([np.nan, 1, 0])
assert object_diff(d0, d1) == ""
assert object_diff(d0, d2) != ""
assert object_diff(np.nan, np.nan) == ""
assert object_diff(np.nan, 3.5) == " value mismatch (nan, 3.5)\n"


def test_hash():
"""Test dictionary hashing and comparison functions."""
# does hashing all of these types work:
# {dict, list, tuple, ndarray, str, float, int, None}
d0 = dict(a=dict(a=0.1, b="fo", c=1), b=[1, "b"], c=(), d=np.ones(3), e=None)
d0[1] = None
d0[2.0] = b"123"
d1 = deepcopy(d0)
assert len(object_diff(d0, d1)) == 0
assert len(object_diff(d1, d0)) == 0
assert object_hash(d0) == object_hash(d1)
# change values slightly
d1["data"] = np.ones(3, int)
d1["d"][0] = 0
assert object_hash(d0) != object_hash(d1)
d1 = deepcopy(d0)
assert object_hash(d0) == object_hash(d1)
d1["a"]["a"] = 0.11
assert len(object_diff(d0, d1)) > 0
assert len(object_diff(d1, d0)) > 0
assert object_hash(d0) != object_hash(d1)
d1 = deepcopy(d0)
assert object_hash(d0) == object_hash(d1)
d1["a"]["d"] = 0 # non-existent key
assert len(object_diff(d0, d1)) > 0
assert len(object_diff(d1, d0)) > 0
assert object_hash(d0) != object_hash(d1)
d1 = deepcopy(d0)
assert object_hash(d0) == object_hash(d1)
d1["b"].append(0) # different-length lists
assert len(object_diff(d0, d1)) > 0
assert len(object_diff(d1, d0)) > 0
assert object_hash(d0) != object_hash(d1)
d1 = deepcopy(d0)
assert object_hash(d0) == object_hash(d1)
d1["e"] = "foo" # non-None
assert len(object_diff(d0, d1)) > 0
assert len(object_diff(d1, d0)) > 0
assert object_hash(d0) != object_hash(d1)
d1 = deepcopy(d0)
d2 = deepcopy(d0)
d1["e"] = StringIO()
d2["e"] = StringIO()
d2["e"].write("foo")
assert len(object_diff(d0, d1)) > 0
assert len(object_diff(d1, d0)) > 0
d1 = deepcopy(d0)
d1[1] = 2
assert len(object_diff(d0, d1)) > 0
assert len(object_diff(d1, d0)) > 0
assert object_hash(d0) != object_hash(d1)
# generators (and other types) not supported
d1 = deepcopy(d0)
d2 = deepcopy(d0)
d1[1] = (x for x in d0)
d2[1] = (x for x in d0)
pytest.raises(RuntimeError, object_diff, d1, d2)
pytest.raises(RuntimeError, object_hash, d1)
x = _eye_array(2, format="csc")
y = _eye_array(2, format="csr")
assert "type mismatch" in object_diff(x, y)
y = _eye_array(2, format="csc")
assert len(object_diff(x, y)) == 0
y[1, 1] = 2
assert "elements" in object_diff(x, y)
y = _eye_array(3, format="csc")
assert "shape" in object_diff(x, y)
y = 0
assert "type mismatch" in object_diff(x, y)
# smoke test for gh-4796
assert object_hash(np.int64(1)) != 0
assert object_hash(np.bool_(True)) != 0
@pytest.mark.parametrize("n_components", (None, 0.9999, 8, "mle"))
@pytest.mark.parametrize("whiten", (True, False))
def test_pca(n_components, whiten):
"""Test PCA equivalence."""
pytest.importorskip("sklearn")
from sklearn.decomposition import PCA
n_samples, n_dim = 1000, 10
X = np.random.RandomState(0).randn(n_samples, n_dim)
X[:, -1] = np.mean(X[:, :-1], axis=-1) # true X dim is ndim - 1
X_orig = X.copy()
pca_skl = PCA(n_components, whiten=whiten, svd_solver="full")
pca_mne = _PCA(n_components, whiten=whiten)
X_skl = pca_skl.fit_transform(X)
assert_array_equal(X, X_orig)
X_mne = pca_mne.fit_transform(X)
assert_array_equal(X, X_orig)
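    # principal components are defined only up to a sign, so align signs first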
assert_allclose(X_skl, X_mne * np.sign(np.sum(X_skl * X_mne, axis=0)))
assert pca_mne.n_components_ == pca_skl.n_components_
for key in (
"mean_",
"components_",
"explained_variance_",
"explained_variance_ratio_",
):
val_skl, val_mne = getattr(pca_skl, key), getattr(pca_mne, key)
if key == "components_":
val_mne = val_mne * np.sign(
np.sum(val_skl * val_mne, axis=1, keepdims=True)
)
assert_allclose(val_skl, val_mne)
if isinstance(n_components, float):
assert pca_mne.n_components_ == n_dim - 1
elif isinstance(n_components, int):
assert pca_mne.n_components_ == n_components
elif n_components == "mle":
assert pca_mne.n_components_ == n_dim - 1
else:
assert n_components is None
assert pca_mne.n_components_ == n_dim


def test_array_equal_nan():
"""Test comparing arrays with NaNs."""
a = b = [1, np.nan, 0]
assert not np.array_equal(a, b) # this is the annoying behavior we avoid
assert _array_equal_nan(a, b)
b = [np.nan, 1, 0]
assert not _array_equal_nan(a, b)
a = b = [np.nan] * 2
assert _array_equal_nan(a, b)


def test_julian_conversions():
    """Test Julian calendar conversions."""
# https://aa.usno.navy.mil/data/docs/JulianDate.php
# A.D. 1922 Jun 13 12:00:00.0 2423219.000000
# A.D. 2018 Oct 3 12:00:00.0 2458395.000000
jds = [2423219, 2458395, 2445701]
cals = [(1922, 6, 13), (2018, 10, 3), (1984, 1, 1)]
dds = [date(*c) for c in cals]
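    # round-trip: Julian day -> calendar date -> Julian day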
for dd, cal, jd in zip(dds, cals, jds):
assert dd == _julian_to_date(jd)
assert jd == _date_to_julian(dd)


def test_grand_average_empty_sequence():
    """Test that mne.grand_average handles an empty sequence correctly."""
with pytest.raises(ValueError, match="Please pass a list of Evoked"):
grand_average([])


def test_grand_average_len_1():
    """Test that mne.grand_average handles a sequence of length 1 correctly."""
# returns a list of length 1
evokeds = read_evokeds(ave_fname, condition=[0], proj=True)
with pytest.warns(RuntimeWarning, match="Only a single dataset"):
gave = grand_average(evokeds)
assert_allclose(gave.data, evokeds[0].data)


def test_reuse_cycle():
"""Test _ReuseCycle."""
vals = "abcde"
iterable = _ReuseCycle(vals)
assert "".join(next(iterable) for _ in range(2 * len(vals))) == vals + vals
# we're back to initial
assert "".join(next(iterable) for _ in range(2)) == "ab"
iterable.restore("a")
assert "".join(next(iterable) for _ in range(10)) == "acdeabcdea"
assert "".join(next(iterable) for _ in range(4)) == "bcde"
# we're back to initial
assert "".join(next(iterable) for _ in range(3)) == "abc"
iterable.restore("a")
iterable.restore("b")
iterable.restore("c")
assert "".join(next(iterable) for _ in range(5)) == "abcde"
# we're back to initial
assert "".join(next(iterable) for _ in range(3)) == "abc"
iterable.restore("a")
iterable.restore("c")
assert "".join(next(iterable) for _ in range(4)) == "acde"
assert "".join(next(iterable) for _ in range(5)) == "abcde"
# we're back to initial
assert "".join(next(iterable) for _ in range(3)) == "abc"
iterable.restore("c")
iterable.restore("a")
with pytest.warns(RuntimeWarning, match="Could not find"):
iterable.restore("a")
assert "".join(next(iterable) for _ in range(4)) == "acde"
assert "".join(next(iterable) for _ in range(5)) == "abcde"
@pytest.mark.parametrize("n", (0, 1, 10, 1000))
@pytest.mark.parametrize("d", (0.0001, 1, 2.5, 1000))
def test_arange_div(numba_conditional, n, d):
"""Test Numba arange_div."""
want = np.arange(n) / d
got = numerics._arange_div(n, d)
assert_allclose(got, want)


def test_custom_lru_cache():
"""Test our _custom_lru_cache implementation."""
n_calls = [0, 0]
start_size = len(_LRU_CACHES)

    @_custom_lru_cache(2)
def my_fun(*args):
n_calls[0] += 1
return ", ".join(arg.__class__.__name__ for arg in args)

    assert len(_LRU_CACHES) == start_size + 1
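    # the most recently registered cache belongs to the function just decorated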
fun_hash = list(_LRU_CACHES)[-1]
assert _LRU_CACHE_MAXSIZES[fun_hash] == 2

    @_custom_lru_cache(1)
def my_fun_2(*args):
n_calls[1] += 1
return ", ".join(arg.__class__.__name__ for arg in args)

    assert len(_LRU_CACHES) == start_size + 2
fun_2_hash = list(_LRU_CACHES)[-1]
assert _LRU_CACHE_MAXSIZES[fun_2_hash] == 1
assert n_calls == [0, 0]
assert my_fun(1, 2, 3) == "int, int, int"
assert n_calls == [1, 0]
assert my_fun_2(1, 2, 3.0) == "int, int, float"
assert n_calls == [1, 1]
# repeated calls use cached version
assert my_fun(1, 2, 3) == "int, int, int"
assert n_calls == [1, 1]
assert my_fun_2(1, 2, 3.0) == "int, int, float"
assert n_calls == [1, 1]
assert len(_LRU_CACHES[fun_hash]) == 1
assert len(_LRU_CACHES[fun_2_hash]) == 1
assert my_fun(1, np.array([2]), 3) == "int, ndarray, int"
assert n_calls == [2, 1]
assert len(_LRU_CACHES[fun_hash]) == 2
assert my_fun_2(1, _eye_array(1, format="csc")) == "int, csc_array"
assert n_calls == [2, 2]
assert len(_LRU_CACHES[fun_2_hash]) == 1 # other got popped
# we could add support for this eventually, but don't bother for now
with pytest.raises(RuntimeError, match="Unsupported sparse type"):
my_fun_2(1, _eye_array(1, format="coo"))
assert n_calls == [2, 2] # never did any computation


def test_replace_md5(tmp_path):
"""Test _replace_md5."""
old = tmp_path / "test"
new = old.with_suffix(".new")
old.write_text("abcd")
new.write_text("abcde")
assert old.is_file()
assert new.is_file()
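    # MD5 differs, so the new file should replace the old one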
_replace_md5(str(new))
assert not new.is_file()
assert old.read_text() == "abcde"
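    # identical MD5: the new file is simply removed and the old one kept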
new.write_text(old.read_text())
_replace_md5(str(new))
assert old.read_text() == "abcde"
assert not new.is_file()