# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np
import pytest
from numpy.testing import (
    assert_allclose,
    assert_array_almost_equal,
    assert_array_equal,
)
from scipy.signal import welch
from mne.time_frequency import psd_array_multitaper, psd_array_welch
from mne.time_frequency.multitaper import _psd_from_mt
from mne.time_frequency.psd import _median_biases
from mne.utils import catch_logging


def test_psd_nan():
"""Test handling of NaN in psd_array_welch."""
n_samples, n_fft, n_overlap = 2048, 1024, 512
x = np.random.RandomState(0).randn(1, n_samples)
psds, freqs = psd_array_welch(
x[:, : n_fft + n_overlap], float(n_fft), n_fft=n_fft, n_overlap=n_overlap
)
x[:, n_fft + n_overlap :] = np.nan # what Raw.get_data() will give us
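    # trailing NaN samples should simply be skipped, so the result should match the
    # computation on the NaN-free slice above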
psds_2, freqs_2 = psd_array_welch(x, float(n_fft), n_fft=n_fft, n_overlap=n_overlap)
assert_allclose(freqs, freqs_2)
assert_allclose(psds, psds_2)
# 1-d
psds_2, freqs_2 = psd_array_welch(
x[0], float(n_fft), n_fft=n_fft, n_overlap=n_overlap
)
assert_allclose(freqs, freqs_2)
assert_allclose(psds[0], psds_2)
# defaults
with catch_logging() as log:
psd_array_welch(x, float(n_fft), verbose="debug")
log = log.getvalue()
assert "using 256-point FFT on 256 samples with 0 overlap" in log
assert "hamming window" in log


def test_bad_annot_handling():
    """Make sure results are equivalent with/without Annotations."""
n_per_seg = 256
n_chan = 3
n_times = 5 * n_per_seg
x = np.random.default_rng(seed=42).standard_normal(size=(n_chan, n_times))
want = psd_array_welch(x, sfreq=100)
# now simulate an annotation that breaks up the array into unequal spans. Using
# `n_per_seg` as the cut point is unrealistic/idealized, but it allows us to test
# whether we get results ~identical to `want` (which we should in this edge case)
x2 = np.concatenate(
(x[..., :n_per_seg], np.full((n_chan, 1), np.nan), x[..., n_per_seg:]), axis=-1
)
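    # the NaN splits x2 into spans of n_per_seg and 4 * n_per_seg samples; with the
    # default n_fft (and hence segment length) of 256, each span yields whole Welch
    # segments, so the averaged result should reproduce `want`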
got = psd_array_welch(x2, sfreq=100)
# freqs should be identical
    assert_array_equal(got[1], want[1])
    # powers should be essentially identical
    assert_allclose(got[0], want[0], rtol=1e-15, atol=0)


def _make_psd_data():
"""Make noise data with sinusoids in 2 out of 7 channels."""
rng = np.random.default_rng(0)
n_chan, n_times, sfreq = 7, 8000, 1000
data = 0.1 * rng.random((n_chan, n_times))
times = np.arange(n_times) / sfreq
sinusoid_freqs = [8.0, 50.0]
chs_with_sinusoids = [0, 1]
for ix, freq in zip(chs_with_sinusoids, sinusoid_freqs):
data[ix, :] += 2 * np.sin(np.pi * 2.0 * freq * times)
return data, sfreq, sinusoid_freqs


@pytest.mark.parametrize(
"psd_func, psd_kwargs",
[
(psd_array_welch, dict(n_fft=128, window="hann")),
(psd_array_multitaper, dict(low_bias=True)),
],
)
def test_psd(psd_func, psd_kwargs):
"""Tests the welch and multitaper PSD."""
data, sfreq, sinusoid_freqs = _make_psd_data()
# prepare kwargs
psd_kwargs.update(dict(fmin=2, fmax=70, verbose="debug"))
# compute PSD and test basic conformity
with catch_logging() as log:
psds, freqs = psd_func(data, sfreq, **psd_kwargs)
if psd_func is psd_array_welch:
log = log.getvalue()
n_fft = psd_kwargs["n_fft"]
assert f"{n_fft}-point FFT on {n_fft} samples with 0 overl" in log
assert "hann window" in log
assert psds.shape == (data.shape[0], len(freqs))
assert np.sum(freqs < 0) == 0
assert np.sum(psds < 0) == 0
# Is power found where it should be?
ixs_max = np.argmax(psds, axis=1)
for ixmax, ifreq in zip(ixs_max, sinusoid_freqs):
# Find nearest frequency to the "true" freq
ixtrue = np.argmin(np.abs(ifreq - freqs))
assert np.abs(ixmax - ixtrue) < 2


def test_psd_array_welch_nperseg_kwarg():
"""Test n_per_seg and padding in psd_array_welch()."""
data, sfreq, _ = _make_psd_data()
# prepare kwargs
kwargs = dict(fmin=2, fmax=70, n_per_seg=128)
# test n_per_seg in psd_array_welch (and padding)
psds1, freqs1 = psd_array_welch(data, sfreq, n_fft=128, **kwargs)
psds2, freqs2 = psd_array_welch(data, sfreq, n_fft=256, **kwargs)
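    # zero-padding each 128-sample segment to n_fft=256 doubles the frequency
    # resolution, so roughly twice as many bins should fall within [fmin, fmax]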
assert len(freqs1) == np.floor(len(freqs2) / 2.0)
assert psds1.shape[-1] == np.floor(psds2.shape[-1] / 2.0)
# test bad n_fft
with pytest.raises(ValueError, match="n_fft is not allowed to be > n_tim"):
kwargs.update(n_per_seg=None)
bad_n_fft = int(data.shape[-1] * 1.1)
psd_array_welch(data, sfreq, n_fft=bad_n_fft, **kwargs)
# test bad n_overlap
with pytest.raises(ValueError, match="n_overlap cannot be greater"):
kwargs.update(n_per_seg=64)
psd_array_welch(data, sfreq, n_fft=128, n_overlap=90, **kwargs)
# test bad fmin/fmax
with pytest.raises(ValueError, match="No frequencies found"):
psd_array_welch(data, sfreq, fmin=10, fmax=1)


def test_complex_multitaper():
"""Test complex-valued multitaper output."""
data, sfreq, _ = _make_psd_data()
psd_complex, freq_complex, weights = psd_array_multitaper(
data[:4, :500], sfreq, output="complex"
)
psd, freq = psd_array_multitaper(data[:4, :500], sfreq, output="power")
assert_array_equal(freq_complex, freq)
assert psd_complex.ndim == 3 # channels x tapers x freqs
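    # recombining the per-taper complex spectra with the taper weights should
    # reproduce the "power" output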
psd_from_complex = _psd_from_mt(psd_complex, weights)
assert_allclose(psd_from_complex, psd)


# Copied from SciPy
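# (gives the factor by which the median of n exponentially distributed periodogram
# estimates is biased relative to their mean, e.g. _median_bias(3) == 1 - 1/2 + 1/3
# (approximately 0.833); the tests below check that the raw median gets divided by
# this factor)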
def _median_bias(n):
ii_2 = 2 * np.arange(1.0, (n - 1) // 2 + 1)
return 1 + np.sum(1.0 / (ii_2 + 1) - 1.0 / ii_2)
@pytest.mark.parametrize("crop", (False, True))
def test_psd_array_welch_average_kwarg(crop):
"""Test `average` kwarg of psd_array_welch()."""
data, sfreq, _ = _make_psd_data()
# prepare kwargs
n_per_seg = 32
kwargs = dict(fmin=0, fmax=np.inf, n_fft=64, n_per_seg=n_per_seg, n_overlap=0)
# optionally crop data by n_per_seg so that we are sure to test both an
# odd number and an even number of estimates (for median bias)
if crop:
data = data[..., :-n_per_seg]
# run with average=mean/median/None
psds_mean, freqs_mean = psd_array_welch(data, sfreq, average="mean", **kwargs)
psds_median, freqs_median = psd_array_welch(data, sfreq, average="median", **kwargs)
psds_unagg, freqs_unagg = psd_array_welch(data, sfreq, average=None, **kwargs)
# Frequencies should be equal across all "average" types, as we feed in
# the exact same data.
assert_array_equal(freqs_mean, freqs_median)
assert_array_equal(freqs_mean, freqs_unagg)
# For `average=None`, the last dimension contains the un-aggregated
# segments.
assert psds_mean.shape == psds_median.shape
assert psds_mean.shape == psds_unagg.shape[:-1]
assert_array_equal(psds_mean, psds_unagg.mean(axis=-1))
# Compare with manual median calculation (_median_bias copied from SciPy)
bias = _median_bias(psds_unagg.shape[-1])
assert_allclose(psds_median, np.median(psds_unagg, axis=-1) / bias)
# check shape of unagg
n_chan, n_times = data.shape
n_freq = len(freqs_unagg)
n_segs = np.ceil(n_times / n_per_seg).astype(int)
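    # 8000 samples / 32 per segment = 250 segments; cropping by one segment length
    # leaves 249, so both the even and the odd median-bias cases get exercised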
assert n_segs % 2 == (1 if crop else 0)
assert psds_unagg.shape == (n_chan, n_freq, n_segs)
@pytest.mark.parametrize("n", (2, 3, 5, 8, 12, 13, 14, 15))
def test_median_biases(n):
"""Test vectorization of median_biases."""
want_biases = np.concatenate(
([1.0, 1.0], [_median_bias(ii) for ii in range(2, n + 1)])
)
got_biases = _median_biases(n)
assert_allclose(want_biases, got_biases)
assert_allclose(got_biases[n], _median_bias(n))
assert_allclose(got_biases[:3], 1.0)


@pytest.mark.slowtest
def test_compares_psd():
"""Test PSD estimation on raw for plt.psd and scipy.signal.welch."""
data, sfreq, _ = _make_psd_data()
n_fft = 2048
fmin, fmax = 2, 70
# Compute PSD with psd_array_welch
psds_mne, freqs_mne = psd_array_welch(
data, sfreq, fmin=fmin, fmax=fmax, n_fft=n_fft
)
    # Compute PSDs with scipy.signal.welch
freqs_scipy, psds_scipy = welch(
data, fs=sfreq, nperseg=n_fft, noverlap=0, window="hamming"
)
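    # both calls use a Hamming window, 2048-sample segments, and no overlap, so the
    # two implementations should agree closely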
# restrict to the relevant frequencies
mask = (freqs_scipy >= fmin) & (freqs_scipy <= fmax)
freqs_scipy = freqs_scipy[mask]
psds_scipy = psds_scipy[:, mask]
# make sure they match
assert_array_almost_equal(psds_mne, psds_scipy)
assert_array_equal(freqs_mne, freqs_scipy)
assert psds_mne.shape == (data.shape[0], len(freqs_mne))
assert psds_scipy.shape == (data.shape[0], len(freqs_scipy))
assert np.sum(freqs_mne < 0) == 0
assert np.sum(freqs_scipy < 0) == 0
assert np.sum(psds_mne < 0) == 0
assert np.sum(psds_scipy < 0) == 0


def test_psd_array_welch_n_jobs():
"""Test that n_jobs works even with more jobs than channels."""
data = np.zeros((1, 2048))
psd_array_welch(data, 1024, n_jobs=1)
psd_array_welch(data, 1024, n_jobs=2)