[074d3d]: / mne / time_frequency / tests / test_stockwell.py

Download this file

153 lines (132 with data), 5.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from pathlib import Path
import numpy as np
import pytest
from numpy.testing import (
assert_allclose,
assert_array_almost_equal,
assert_array_less,
assert_equal,
)
from scipy import fftpack
from mne import Epochs, make_fixed_length_events, read_events
from mne.io import read_raw_fif
from mne.time_frequency import AverageTFR, tfr_array_stockwell
from mne.time_frequency._stockwell import (
_check_input_st,
_precompute_st_windows,
_st,
_st_power_itc,
tfr_stockwell,
)
from mne.utils import _record_warnings
# Test data shipped with the ``mne.io`` test suite. ``parents[2]`` of this
# file (mne/time_frequency/tests/test_stockwell.py) is the ``mne`` package
# directory, so the data lives in mne/io/tests/data.
base_dir = Path(__file__).parents[2].joinpath("io", "tests", "data")
raw_fname = base_dir.joinpath("test_raw.fif")
raw_ctf_fname = base_dir.joinpath("test_ctf_raw.fif")
def test_stockwell_ctf():
    """Test that Stockwell can be calculated on CTF data."""
    raw_ctf = read_raw_fif(raw_ctf_fname)
    # Apply gradient compensation grade 3 before epoching.
    raw_ctf.apply_gradient_compensation(3)
    fixed_events = make_fixed_length_events(raw_ctf, duration=0.5)
    ctf_epochs = Epochs(
        raw_ctf,
        fixed_events,
        tmin=-0.2,
        tmax=0.3,
        decim=10,
        preload=True,
        verbose="error",
    )
    # Smoke test: only checks that the transform runs on the evoked data.
    tfr_stockwell(ctf_epochs.average(), verbose="error")
def test_stockwell_check_input():
    """Test input checker for stockwell."""
    # One length that is not a power of 2 (127) and one that is (128); both
    # are expected to come back padded to 128 samples.
    expected_n_fft = 128
    for n_times in (127, 128):
        x = np.zeros((2, 10, n_times))
        # The checker may warn about n_fft for some lengths; tolerate it.
        with _record_warnings():
            x_padded, n_fft, zero_pad = _check_input_st(x, None)
        assert_equal(x_padded.shape, (2, 10, expected_n_fft))
        assert_equal(n_fft, expected_n_fft)
        assert_equal(zero_pad, expected_n_fft - n_times)
def test_stockwell_st_no_zero_pad():
    """Test stockwell power/ITC computation when no zero padding is used."""
    data = np.zeros((20, 128))
    start_f = 1
    stop_f = 10
    sfreq = 30
    width = 2
    W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)
    # Pass the same ``start_f`` the windows were precomputed for, so each
    # window is applied to its own frequency bin (zero_pad=0, decim=1).
    psd, itc = _st_power_itc(data, start_f, True, 0, 1, W)
    # One row per precomputed window, one column per (undecimated) sample.
    assert psd.shape == (len(W), data.shape[-1])
    assert itc.shape == psd.shape
    # All-zero input must give all-zero power.
    assert_allclose(psd, 0.0)
def test_stockwell_core():
    """Test stockwell transform."""
    # adapted from
    # http://vcs.ynic.york.ac.uk/docs/naf/intro/concepts/timefreq.html
    sfreq = 1000.0  # make things easy to understand
    dur = 0.5
    onset, offset = 0.175, 0.275
    n_samp = int(sfreq * dur)
    t = np.arange(n_samp) / sfreq  # make an array for time
    pulse_freq = 15.0
    pulse = np.cos(2.0 * np.pi * pulse_freq * t)
    pulse[0 : int(onset * sfreq)] = 0.0  # Zero before our desired pulse
    pulse[int(offset * sfreq) :] = 0.0  # and zero after our desired pulse

    width = 0.5
    freqs = fftpack.fftfreq(len(pulse), 1.0 / sfreq)
    fmin, fmax = 1.0, 100.0
    # FFT bin indices closest to the requested frequency range.
    start_f, stop_f = (np.abs(freqs - f).argmin() for f in (fmin, fmax))
    W = _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width)
    st_pulse = _st(pulse, start_f, W)
    st_pulse = np.abs(st_pulse) ** 2
    assert_equal(st_pulse.shape[-1], len(pulse))
    # The windows span bins ``start_f:stop_f``, so row ``i`` of the transform
    # maps to ``freqs[start_f + i]``; offset the row index before lookup
    # (start_f happens to be 0 here, but the lookup must not rely on that).
    max_row = st_pulse.max(axis=1).argmax(axis=0)
    st_max_freq = freqs[start_f + max_row]  # max freq
    assert_allclose(st_max_freq, pulse_freq, atol=1.0)
    # The power peak in time must fall inside the pulse.
    assert onset < t[st_pulse.max(axis=0).argmax(axis=0)] < offset

    # test inversion to FFT, by averaging local spectra, see eq. 5 in
    # Moukadem, A., Bouguila, Z., Ould Abdeslam, D. and Alain Dieterlen.
    # "Stockwell transform optimization applied on the detection of split in
    # heart sounds."
    width = 1.0
    start_f, stop_f = 0, len(pulse)  # every FFT bin is needed to invert
    W = _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width)
    y = _st(pulse, start_f, W)
    # invert stockwell: summing over the frequency rows recovers the FFT
    y_inv = fftpack.ifft(np.sum(y, axis=1)).real
    assert_array_almost_equal(pulse, y_inv)
def test_stockwell_api():
    """Test stockwell functions."""
    raw = read_raw_fif(raw_fname)
    event_id, tmin, tmax = 1, -0.2, 0.5
    event_name = base_dir / "test-eve.fif"
    events = read_events(event_name)
    epochs = Epochs(
        raw,
        events,  # XXX pick 2 has epochs of zeros.
        event_id,
        tmin,
        tmax,
        picks=[0, 1, 3],
    )
    for fmin, fmax in [(None, 50), (5, 50), (5, None)]:
        power, itc = tfr_stockwell(epochs, fmin=fmin, fmax=fmax, return_itc=True)
        if fmax is not None:
            assert power.freqs.max() <= fmax
        power_evoked = tfr_stockwell(
            epochs.average(), fmin=fmin, fmax=fmax, return_itc=False
        )
        # for multitaper these don't necessarily match, but they seem to
        # for stockwell... if this fails, this maybe could be changed
        # just to check the shape
        assert_array_almost_equal(power_evoked.data, power.data)

    # The checks below use the results of the last (fmin, fmax) pair.
    assert isinstance(power, AverageTFR)
    assert isinstance(itc, AverageTFR)
    assert_equal(power.data.shape, itc.data.shape)
    # ITC is bounded to [0, 1].
    assert itc.data.min() >= 0.0
    assert itc.data.max() <= 1.0
    # 20 * ln(max power) <= 0, i.e. the peak power does not exceed 1.
    # (A verbatim duplicate of this assertion was removed.)
    assert np.log(power.data.max()) * 20 <= 0.0

    # Array API: input validation.
    with pytest.raises(TypeError, match="ndarray"):
        tfr_array_stockwell("foo", 1000.0)
    data = np.random.RandomState(0).randn(1, 1024)
    with pytest.raises(ValueError, match="3D with shape"):
        tfr_array_stockwell(data, 1000.0)
    # Add a leading axis -> shape (1, 1, 1024); presumably a single epoch,
    # for which ITC is trivially 1 everywhere.
    data = data[np.newaxis]
    power, itc, freqs = tfr_array_stockwell(data, 1000.0, return_itc=True)
    assert_allclose(itc, np.ones_like(itc))
    assert power.shape == (1, len(freqs), data.shape[-1])
    assert_array_less(0, power)