[074d3d]: / mne / time_frequency / tests / test_psd.py

Download this file

229 lines (204 with data), 9.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
from scipy.signal import welch
from mne.time_frequency import psd_array_multitaper, psd_array_welch
from mne.time_frequency.multitaper import _psd_from_mt
from mne.time_frequency.psd import _median_biases
from mne.utils import catch_logging
def test_psd_nan():
    """Check that trailing NaN samples are ignored by psd_array_welch."""
    n_samples, n_fft, n_overlap = 2048, 1024, 512
    sfreq = float(n_fft)
    welch_kw = dict(n_fft=n_fft, n_overlap=n_overlap)
    n_valid = n_fft + n_overlap
    x = np.random.RandomState(0).randn(1, n_samples)
    # Reference estimate computed from the finite leading part only.
    psds, freqs = psd_array_welch(x[:, :n_valid], sfreq, **welch_kw)
    # Fill the tail with NaN (this mimics what Raw.get_data() returns); the
    # estimate on the full array must match the reference.
    x[:, n_valid:] = np.nan
    psds_nan, freqs_nan = psd_array_welch(x, sfreq, **welch_kw)
    assert_allclose(freqs, freqs_nan)
    assert_allclose(psds, psds_nan)
    # A 1-D input must give the single-channel result of the 2-D call.
    psds_1d, freqs_1d = psd_array_welch(x[0], sfreq, **welch_kw)
    assert_allclose(freqs, freqs_1d)
    assert_allclose(psds[0], psds_1d)
    # The defaults are reported in the debug log.
    with catch_logging() as log:
        psd_array_welch(x, sfreq, verbose="debug")
    log_text = log.getvalue()
    assert "using 256-point FFT on 256 samples with 0 overlap" in log_text
    assert "hamming window" in log_text
def test_bad_annot_handling():
    """Check that a single NaN gap (like a bad annotation) leaves the PSD unchanged."""
    n_per_seg = 256
    n_chan = 3
    rng = np.random.default_rng(seed=42)
    x = rng.standard_normal(size=(n_chan, 5 * n_per_seg))
    psd_want, freqs_want = psd_array_welch(x, sfreq=100)
    # Insert one NaN column at sample ``n_per_seg`` so the array is broken into
    # unequal finite spans. Cutting exactly on a segment boundary is idealized,
    # but it means the result should be ~identical to the gap-free estimate.
    nan_col = np.full((n_chan, 1), np.nan)
    x_gap = np.concatenate((x[:, :n_per_seg], nan_col, x[:, n_per_seg:]), axis=-1)
    psd_got, freqs_got = psd_array_welch(x_gap, sfreq=100)
    # Frequencies identical; powers equal to within floating-point noise.
    assert_array_equal(freqs_got, freqs_want)
    assert_allclose(psd_got, psd_want, rtol=1e-15, atol=0)
def _make_psd_data():
    """Build noisy test data with sinusoids added to 2 of its 7 channels.

    Returns
    -------
    data : ndarray, shape (7, 8000)
        Uniform noise scaled by 0.1 (seeded, reproducible). Channel 0 also
        carries a 2-amplitude 8 Hz sinusoid and channel 1 a 50 Hz one.
    sfreq : int
        Sampling frequency, 1000.
    sinusoid_freqs : list of float
        ``[8.0, 50.0]``, the frequency added to channel 0 and 1 respectively.
    """
    sfreq = 1000
    n_chan, n_times = 7, 8000
    sinusoid_freqs = [8.0, 50.0]
    data = 0.1 * np.random.default_rng(0).random((n_chan, n_times))
    t = np.arange(n_times) / sfreq
    # Channel index k receives sinusoid_freqs[k].
    for ch, f in enumerate(sinusoid_freqs):
        data[ch, :] += 2 * np.sin(np.pi * 2.0 * f * t)
    return data, sfreq, sinusoid_freqs
@pytest.mark.parametrize(
    "psd_func, psd_kwargs",
    [
        (psd_array_welch, dict(n_fft=128, window="hann")),
        (psd_array_multitaper, dict(low_bias=True)),
    ],
)
def test_psd(psd_func, psd_kwargs):
    """Test the Welch and multitaper PSD estimators.

    Checks the output shape, that frequencies and powers are non-negative,
    that each sinusoid-carrying channel peaks within one bin of its true
    frequency, and (for Welch) the debug log messages.
    """
    data, sfreq, sinusoid_freqs = _make_psd_data()
    # Build a fresh dict instead of calling ``psd_kwargs.update(...)``. The
    # parametrized dict literal is evaluated once when the decorator runs, so
    # mutating it in place would leak into any re-run of this parametrization.
    psd_kwargs = dict(psd_kwargs, fmin=2, fmax=70, verbose="debug")
    # compute PSD and test basic conformity
    with catch_logging() as log:
        psds, freqs = psd_func(data, sfreq, **psd_kwargs)
    if psd_func is psd_array_welch:
        log_text = log.getvalue()
        n_fft = psd_kwargs["n_fft"]
        assert f"{n_fft}-point FFT on {n_fft} samples with 0 overl" in log_text
        assert "hann window" in log_text
    assert psds.shape == (data.shape[0], len(freqs))
    assert np.sum(freqs < 0) == 0
    assert np.sum(psds < 0) == 0
    # Is power found where it should be? Only the first len(sinusoid_freqs)
    # channels carry a sinusoid, so zip() stops after those channels.
    ixs_max = np.argmax(psds, axis=1)
    for ixmax, ifreq in zip(ixs_max, sinusoid_freqs):
        # Find nearest frequency to the "true" freq
        ixtrue = np.argmin(np.abs(ifreq - freqs))
        assert np.abs(ixmax - ixtrue) < 2
def test_psd_array_welch_nperseg_kwarg():
    """Test n_per_seg, padding and argument validation in psd_array_welch()."""
    data, sfreq, _ = _make_psd_data()
    kwargs = dict(fmin=2, fmax=70, n_per_seg=128)
    # With n_per_seg fixed at 128, raising n_fft from 128 to 256 (padding each
    # segment) halves the bin spacing, giving about twice as many bins in
    # [fmin, fmax].
    psds1, freqs1 = psd_array_welch(data, sfreq, n_fft=128, **kwargs)
    psds2, freqs2 = psd_array_welch(data, sfreq, n_fft=256, **kwargs)
    assert len(freqs1) == len(freqs2) // 2
    assert psds1.shape[-1] == psds2.shape[-1] // 2
    # Setup is kept outside each ``pytest.raises`` block so that only the call
    # under test can satisfy the expected exception.
    # n_fft longer than the signal itself must be rejected.
    kwargs.update(n_per_seg=None)
    bad_n_fft = int(data.shape[-1] * 1.1)
    with pytest.raises(ValueError, match="n_fft is not allowed to be > n_tim"):
        psd_array_welch(data, sfreq, n_fft=bad_n_fft, **kwargs)
    # n_overlap (90) larger than n_per_seg (64) must be rejected.
    kwargs.update(n_per_seg=64)
    with pytest.raises(ValueError, match="n_overlap cannot be greater"):
        psd_array_welch(data, sfreq, n_fft=128, n_overlap=90, **kwargs)
    # An empty [fmin, fmax] range (fmin > fmax) must be rejected.
    with pytest.raises(ValueError, match="No frequencies found"):
        psd_array_welch(data, sfreq, fmin=10, fmax=1)
def test_complex_multitaper():
    """Check that complex multitaper output can be reduced to the power output."""
    data, sfreq, _ = _make_psd_data()
    segment = data[:4, :500]
    mt_complex, freqs_c, weights = psd_array_multitaper(
        segment, sfreq, output="complex"
    )
    mt_power, freqs_p = psd_array_multitaper(segment, sfreq, output="power")
    assert_array_equal(freqs_c, freqs_p)
    # Complex output keeps a per-taper axis: channels x tapers x freqs.
    assert mt_complex.ndim == 3
    # Combining the tapered spectra with their weights reproduces the power.
    assert_allclose(_psd_from_mt(mt_complex, weights), mt_power)
# Same formula as SciPy's private ``_median_bias`` helper (scipy.signal).
def _median_bias(n):
    """Return the bias-correction factor for the median of ``n`` estimates.

    Computes ``1 + sum_{k=1}^{(n - 1) // 2} (1 / (2k + 1) - 1 / (2k))``. The
    sum is empty, so the result is 1, whenever ``n < 3``.
    """
    n_terms = (n - 1) // 2
    even = 2.0 * np.arange(1, n_terms + 1)
    return 1 + np.sum(1.0 / (even + 1) - 1.0 / even)
@pytest.mark.parametrize("crop", (False, True))
def test_psd_array_welch_average_kwarg(crop):
    """Check the "mean", "median" and None ``average`` modes of psd_array_welch()."""
    data, sfreq, _ = _make_psd_data()
    n_per_seg = 32
    if crop:
        # Dropping one segment's worth of samples flips the parity of the
        # segment count, so both odd and even counts reach the median bias.
        data = data[..., :-n_per_seg]
    kwargs = dict(fmin=0, fmax=np.inf, n_fft=64, n_per_seg=n_per_seg, n_overlap=0)

    def _run(average):
        return psd_array_welch(data, sfreq, average=average, **kwargs)

    psds_mean, freqs_mean = _run("mean")
    psds_median, freqs_median = _run("median")
    psds_unagg, freqs_unagg = _run(None)
    # Identical input, so the frequencies cannot depend on ``average``.
    assert_array_equal(freqs_mean, freqs_median)
    assert_array_equal(freqs_mean, freqs_unagg)
    # average=None keeps the per-segment estimates on a trailing axis.
    assert psds_mean.shape == psds_median.shape
    assert psds_mean.shape == psds_unagg.shape[:-1]
    assert_array_equal(psds_mean, psds_unagg.mean(axis=-1))
    # Median must equal the plain median divided by SciPy's bias factor.
    median_ref = np.median(psds_unagg, axis=-1) / _median_bias(psds_unagg.shape[-1])
    assert_allclose(psds_median, median_ref)
    # Full shape of the un-aggregated output.
    n_chan, n_times = data.shape
    n_segs = np.ceil(n_times / n_per_seg).astype(int)
    assert n_segs % 2 == (1 if crop else 0)
    assert psds_unagg.shape == (n_chan, len(freqs_unagg), n_segs)
@pytest.mark.parametrize("n", (2, 3, 5, 8, 12, 13, 14, 15))
def test_median_biases(n):
    """Check that the vectorized _median_biases matches the scalar helper."""
    got_biases = _median_biases(n)
    # Entry k is the bias for k estimates; entries 0 and 1 are expected to be 1.
    want_biases = np.array([1.0, 1.0] + [_median_bias(k) for k in range(2, n + 1)])
    assert_allclose(want_biases, got_biases)
    assert_allclose(got_biases[n], _median_bias(n))
    # Fewer than three estimates means no correction at all.
    assert_allclose(got_biases[:3], 1.0)
@pytest.mark.slowtest
def test_compares_psd():
    """Test that psd_array_welch agrees with scipy.signal.welch.

    psd_array_welch is called with only ``n_fft`` set, so its default window
    and overlap are compared against a SciPy call configured explicitly with
    a Hamming window and no overlap.
    """
    data, sfreq, _ = _make_psd_data()
    n_fft = 2048
    fmin, fmax = 2, 70
    # Compute PSD with psd_array_welch (default window / overlap)
    psds_mne, freqs_mne = psd_array_welch(
        data, sfreq, fmin=fmin, fmax=fmax, n_fft=n_fft
    )
    # Compute PSD with scipy.signal.welch over the full frequency range ...
    freqs_scipy, psds_scipy = welch(
        data, fs=sfreq, nperseg=n_fft, noverlap=0, window="hamming"
    )
    # ... then restrict it to the same [fmin, fmax] band
    mask = (freqs_scipy >= fmin) & (freqs_scipy <= fmax)
    freqs_scipy = freqs_scipy[mask]
    psds_scipy = psds_scipy[:, mask]
    # make sure they match
    assert_array_almost_equal(psds_mne, psds_scipy)
    assert_array_equal(freqs_mne, freqs_scipy)
    assert psds_mne.shape == (data.shape[0], len(freqs_mne))
    assert psds_scipy.shape == (data.shape[0], len(freqs_scipy))
    # Frequencies and powers from both implementations are non-negative
    assert np.sum(freqs_mne < 0) == 0
    assert np.sum(freqs_scipy < 0) == 0
    assert np.sum(psds_mne < 0) == 0
    assert np.sum(psds_scipy < 0) == 0
def test_psd_array_welch_n_jobs():
    """Check that psd_array_welch runs even when n_jobs exceeds the channel count."""
    data = np.zeros((1, 2048))  # a single channel
    for n_jobs in (1, 2):
        psd_array_welch(data, 1024, n_jobs=n_jobs)