# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from pathlib import Path
import numpy as np
import pytest
from numpy.testing import (
assert_allclose,
assert_array_almost_equal,
assert_array_less,
assert_equal,
)
from scipy import fftpack
from mne import Epochs, make_fixed_length_events, read_events
from mne.io import read_raw_fif
from mne.time_frequency import AverageTFR, tfr_array_stockwell
from mne.time_frequency._stockwell import (
_check_input_st,
_precompute_st_windows,
_st,
_st_power_itc,
tfr_stockwell,
)
from mne.utils import _record_warnings
# Locations of the FIF fixtures used below, resolved relative to this file so
# the tests do not depend on the current working directory.
base_dir = Path(__file__).parents[2].joinpath("io", "tests", "data")
raw_fname = base_dir.joinpath("test_raw.fif")
raw_ctf_fname = base_dir.joinpath("test_ctf_raw.fif")
def test_stockwell_ctf():
    """Smoke-test that the Stockwell transform runs on CTF data."""
    ctf_raw = read_raw_fif(raw_ctf_fname)
    # Select CTF gradient compensation grade 3 before epoching.
    ctf_raw.apply_gradient_compensation(3)
    fixed_events = make_fixed_length_events(ctf_raw, duration=0.5)
    ctf_epochs = Epochs(
        ctf_raw,
        fixed_events,
        tmin=-0.2,
        tmax=0.3,
        decim=10,
        preload=True,
        verbose="error",
    )
    # Only checks that the transform completes without raising.
    tfr_stockwell(ctf_epochs.average(), verbose="error")
def test_stockwell_check_input():
    """Check that _check_input_st zero-pads the time axis to a power of 2."""
    # 127 samples need one sample of padding to reach 128; 128 need none.
    for n_times in (127, 128):
        x = np.zeros((2, 10, n_times))
        # The n_fft selection may emit a warning for some sizes; tolerate it.
        with _record_warnings():
            padded, n_fft, n_pad = _check_input_st(x, None)
        assert_equal(padded.shape, (2, 10, 128))
        assert_equal(n_fft, 128)
        assert_equal(n_pad, 128 - n_times)
def test_stockwell_st_no_zero_pad():
    """Test _st_power_itc on input that was not zero-padded (zero_pad=0)."""
    data = np.zeros((20, 128))
    start_f = 1
    stop_f = 10
    sfreq = 30
    width = 2
    W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)
    # Pass the same start_f the windows were precomputed for; a hard-coded
    # value would apply each window to a different frequency bin than the
    # one it was built for.
    psd, itc = _st_power_itc(data, start_f, True, 0, 1, W)
    # With zero_pad=0 and decim=1 no time samples should be trimmed, and one
    # output row is expected per precomputed window.
    assert psd.shape == (len(W), data.shape[-1])
    assert itc is not None
    assert itc.shape == psd.shape
def test_stockwell_core():
    """Test the core Stockwell transform.

    Checks that a windowed cosine is localized at the right frequency and
    time. Also checks that summing the transform over time and inverting
    with an inverse FFT recovers the original signal.
    """
    # adapted from
    # http://vcs.ynic.york.ac.uk/docs/naf/intro/concepts/timefreq.html
    sfreq = 1000.0  # make things easy to understand
    dur = 0.5
    onset, offset = 0.175, 0.275
    n_samp = int(sfreq * dur)
    t = np.arange(n_samp) / sfreq  # make an array for time
    pulse_freq = 15.0
    pulse = np.cos(2.0 * np.pi * pulse_freq * t)
    pulse[0 : int(onset * sfreq)] = 0.0  # Zero before our desired pulse
    pulse[int(offset * sfreq) :] = 0.0  # and zero after our desired pulse

    width = 0.5
    freqs = fftpack.fftfreq(len(pulse), 1.0 / sfreq)
    fmin, fmax = 1.0, 100.0
    # FFT bin indices closest to fmin / fmax.
    start_f, stop_f = (np.abs(freqs - f).argmin() for f in (fmin, fmax))
    W = _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width)
    st_pulse = _st(pulse, start_f, W)
    st_pulse = np.abs(st_pulse) ** 2
    assert_equal(st_pulse.shape[-1], len(pulse))
    # The windows cover bins start_f..stop_f-1, so row i of the transform
    # maps to FFT bin start_f + i. Offset the row index before looking up
    # its frequency, otherwise this only works when start_f == 0.
    peak_row = st_pulse.max(axis=1).argmax(axis=0)
    st_max_freq = freqs[start_f + peak_row]  # max freq
    assert_allclose(st_max_freq, pulse_freq, atol=1.0)
    # The energy peak in time must fall inside the non-zero part of the pulse.
    assert onset < t[st_pulse.max(axis=0).argmax(axis=0)] < offset

    # test inversion to FFT, by averaging local spectra, see eq. 5 in
    # Moukadem, A., Bouguila, Z., Ould Abdeslam, D. and Alain Dieterlen.
    # "Stockwell transform optimization applied on the detection of split in
    # heart sounds."
    width = 1.0
    start_f, stop_f = 0, len(pulse)
    W = _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width)
    y = _st(pulse, start_f, W)
    # invert stockwell: sum each frequency row over time, then inverse FFT
    y_inv = fftpack.ifft(np.sum(y, axis=1)).real
    assert_array_almost_equal(pulse, y_inv)
def test_stockwell_api():
    """Test the public Stockwell API (tfr_stockwell and tfr_array_stockwell)."""
    raw = read_raw_fif(raw_fname)
    event_id, tmin, tmax = 1, -0.2, 0.5
    event_name = base_dir / "test-eve.fif"
    events = read_events(event_name)
    epochs = Epochs(
        raw,
        events,  # XXX pick 2 has epochs of zeros.
        event_id,
        tmin,
        tmax,
        picks=[0, 1, 3],
    )
    for fmin, fmax in [(None, 50), (5, 50), (5, None)]:
        power, itc = tfr_stockwell(epochs, fmin=fmin, fmax=fmax, return_itc=True)
        if fmax is not None:
            assert power.freqs.max() <= fmax
        power_evoked = tfr_stockwell(
            epochs.average(), fmin=fmin, fmax=fmax, return_itc=False
        )
        # for multitaper these don't necessarily match, but they seem to
        # for stockwell... if this fails, this maybe could be changed
        # just to check the shape
        assert_array_almost_equal(power_evoked.data, power.data)
    assert isinstance(power, AverageTFR)
    assert isinstance(itc, AverageTFR)
    assert_equal(power.data.shape, itc.data.shape)
    # ITC is a magnitude of a mean of unit phasors, so it lies in [0, 1].
    assert itc.data.min() >= 0.0
    assert itc.data.max() <= 1.0
    # Log-scale sanity bound (max power <= 1) for both the epochs power and
    # the evoked power; previously the epochs check was simply repeated.
    assert np.log(power.data.max()) * 20 <= 0.0
    assert np.log(power_evoked.data.max()) * 20 <= 0.0

    # Array API: input validation.
    with pytest.raises(TypeError, match="ndarray"):
        tfr_array_stockwell("foo", 1000.0)
    data = np.random.RandomState(0).randn(1, 1024)
    with pytest.raises(ValueError, match="3D with shape"):
        tfr_array_stockwell(data, 1000.0)
    # A single epoch always yields perfect inter-trial coherence.
    data = data[np.newaxis]
    power, itc, freqs = tfr_array_stockwell(data, 1000.0, return_itc=True)
    assert_allclose(itc, np.ones_like(itc))
    assert power.shape == (1, len(freqs), data.shape[-1])
    assert_array_less(0, power)