# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from copy import deepcopy
import numpy as np
from scipy.fft import fft, fftfreq, ifft
from .._fiff.pick import _pick_data_channels, pick_info
from ..parallel import parallel_func
from ..utils import _validate_type, legacy, logger, verbose
from .tfr import AverageTFRArray, _ensure_slice, _get_data
def _check_input_st(x_in, n_fft):
    """Resolve the FFT length and zero-pad the signal up to it.

    Parameters
    ----------
    x_in : ndarray
        Input signal; the last axis is treated as time.
    n_fft : int | None
        Requested FFT length. ``None``, or a value that is not a power of
        two and is smaller than the signal length, is replaced by the
        smallest power of two that is >= the signal length.

    Returns
    -------
    x_in : ndarray
        The input, zero-padded along the last axis when it is shorter than
        the resolved ``n_fft``; otherwise the input object itself.
    n_fft : int
        The resolved FFT length.
    zero_pad : int
        Number of zero samples appended (0 when no padding was needed).

    Raises
    ------
    ValueError
        If ``n_fft`` is smaller than the signal length and is not replaced
        by the rule above (e.g. it is a power of two).
    """
    n_times = x_in.shape[-1]

    if n_fft is None:
        use_next_pow2 = True
    else:
        # Truthy when n_fft is positive and has more than one bit set, i.e.
        # it is not a power of two. Non-positive values count as "power of
        # two" here and are rejected by the size check below.
        extra_bits = n_fft > 0 and (n_fft & (n_fft - 1))
        use_next_pow2 = bool(extra_bits) and n_times > n_fft

    if use_next_pow2:
        n_fft = 2 ** int(np.ceil(np.log2(n_times)))
    elif n_fft < n_times:
        raise ValueError(
            f"n_fft cannot be smaller than signal size. Got {n_fft} < {n_times}."
        )

    if n_times < n_fft:
        logger.info(
            f'The input signal is shorter ({x_in.shape[-1]}) than "n_fft" ({n_fft}). '
            "Applying zero padding."
        )
        zero_pad = n_fft - n_times
        padding = np.zeros(x_in.shape[:-1] + (zero_pad,), x_in.dtype)
        return np.concatenate((x_in, padding), axis=-1), n_fft, zero_pad

    return x_in, n_fft, 0
def _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width):
    """Precompute the Stockwell Gaussian windows (in the frequency domain).

    Parameters
    ----------
    n_samp : int
        Number of samples of the (padded) signal; also the window length.
    start_f : int
        First frequency index (inclusive).
    stop_f : int
        Last frequency index (exclusive).
    sfreq : float
        The sampling frequency.
    width : float
        Width factor of the Gaussian; 1 gives the classical Stockwell
        transform.

    Returns
    -------
    windows : ndarray of complex128, shape (stop_f - start_f, n_samp)
        FFT of one sum-normalised window per frequency index.
    """
    axis = fftfreq(n_samp, 1.0 / sfreq) / n_samp
    # Keep the first sample in place and reverse the order of the others.
    axis = np.concatenate((axis[:1], axis[:0:-1]))
    axis_sq = axis**2.0

    # Loop-invariant pieces of the Gaussian expression.
    amp_denom = np.sqrt(2.0 * np.pi) * width
    exp_scale = -0.5 * (1.0 / width**2.0)

    freq_idx = np.arange(start_f, stop_f, 1)
    windows = np.empty((len(freq_idx), len(axis)), dtype=np.complex128)
    for row, f in enumerate(freq_idx):
        if f == 0.0:
            # Index 0 has no Gaussian; use a flat window.
            gauss = np.ones(len(axis))
        else:
            gauss = (f / amp_denom) * np.exp(exp_scale * (f**2.0) * axis_sq)
        windows[row] = fft(gauss / gauss.sum())  # unit-sum normalisation
    return windows
def _st(x, start_f, windows):
    """Compute ST based on Ali Moukadem MATLAB code (used in tests).

    Parameters
    ----------
    x : ndarray, shape (..., n_samp)
        Signal(s) to transform along the last axis.
    start_f : int
        Frequency index that the first window corresponds to.
    windows : ndarray, shape (n_freqs, n_samp)
        Frequency-domain windows, one per consecutive frequency index.

    Returns
    -------
    ST : ndarray of complex128, shape (..., n_freqs, n_samp)
        The complex Stockwell transform.
    """
    n_samp = x.shape[-1]
    out_shape = x.shape[:-1] + (len(windows), n_samp)
    ST = np.empty(out_shape, dtype=np.complex128)

    spectrum = fft(x)
    # Two copies back to back, so a length-n_samp slice starting at index f
    # is the spectrum circularly shifted by f bins.
    doubled = np.concatenate([spectrum, spectrum], axis=-1)

    for offset, window in enumerate(windows):
        first = start_f + offset
        shifted = doubled[..., first : first + n_samp]
        ST[..., offset, :] = ifft(shifted * window)
    return ST
def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W):
    """Compute epoch-averaged Stockwell power (and optionally ITC).

    Parameters
    ----------
    x : ndarray, shape (n_epochs, n_samp)
        Data of a single channel (already zero-padded to the FFT length).
    start_f : int
        Frequency index that the first window in ``W`` corresponds to.
    compute_itc : bool
        Whether to also compute the intertrial coherence.
    zero_pad : int
        Number of trailing padding samples to discard from the output.
    decim : int | slice
        Decimation along time, normalised through ``_ensure_slice``.
    W : ndarray, shape (n_freqs, n_samp)
        Frequency-domain windows, one per consecutive frequency index.

    Returns
    -------
    psd : ndarray, shape (n_freqs, n_out)
        Squared magnitude of the transform, averaged over axis 0 of ``x``.
    itc : ndarray, shape (n_freqs, n_out) | None
        Magnitude of the average unit-normalised transform, or ``None``
        when ``compute_itc`` is False.
    """
    decim = _ensure_slice(decim)
    n_samp = x.shape[-1]
    # Resolve the decimation slice against the un-padded length so that the
    # zero-padded tail never reaches the output.
    decim_indices = decim.indices(n_samp - zero_pad)
    keep = slice(*decim_indices)
    n_out = len(range(*decim_indices))
    psd = np.empty((len(W), n_out))
    itc = np.empty_like(psd) if compute_itc else None
    X = fft(x)
    # Two copies back to back, so a length-n_samp slice starting at index f
    # is the spectrum circularly shifted by f bins.
    XX = np.concatenate([X, X], axis=-1)
    for i_f, window in enumerate(W):
        f = start_f + i_f
        ST = ifft(XX[:, f : f + n_samp] * window)
        TFR = ST[:, keep]
        TFR_abs = np.abs(TFR)
        if compute_itc:
            # Guard only the phase normalisation against division by zero.
            # The substitute value must not leak into the power estimate,
            # otherwise exactly-zero samples would be reported with power 1.
            safe_abs = np.where(TFR_abs == 0, 1.0, TFR_abs)
            itc[i_f] = np.abs(np.mean(TFR / safe_abs, axis=0))
        psd[i_f] = np.mean(TFR_abs * TFR_abs, axis=0)
    return psd, itc
def _compute_freqs_st(fmin, fmax, n_fft, sfreq):
    """Translate a frequency band into FFT bin indices.

    Parameters
    ----------
    fmin : float | None
        Lower band edge. ``None`` selects the first strictly positive FFT
        frequency.
    fmax : float | None
        Upper band edge. ``None`` selects the largest FFT frequency.
    n_fft : int
        FFT length.
    sfreq : float
        The sampling frequency.

    Returns
    -------
    start_f : int
        Index of the FFT bin closest to ``fmin``.
    stop_f : int
        Index of the FFT bin closest to ``fmax``; used as an exclusive bound.
    freqs : ndarray
        The FFT frequencies of bins ``start_f`` up to (not including)
        ``stop_f``.
    """
    bins = fftfreq(n_fft, 1.0 / sfreq)
    low = bins[bins > 0][0] if fmin is None else fmin
    high = bins.max() if fmax is None else fmax

    start_f = np.abs(bins - low).argmin()
    stop_f = np.abs(bins - high).argmin()
    return start_f, stop_f, bins[start_f:stop_f]
@verbose
def tfr_array_stockwell(
    data,
    sfreq,
    fmin=None,
    fmax=None,
    n_fft=None,
    width=1.0,
    decim=1,
    return_itc=False,
    n_jobs=None,
    *,
    verbose=None,
):
    """Compute power and intertrial coherence using Stockwell (S) transform.

    Same computation as `~mne.time_frequency.tfr_stockwell`, but operates on
    :class:`NumPy arrays <numpy.ndarray>` instead of `~mne.Epochs` objects.

    See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
    for more information.

    Parameters
    ----------
    data : ndarray, shape (n_epochs, n_channels, n_times)
        The signal to transform.
    sfreq : float
        The sampling frequency.
    fmin : None, float
        The minimum frequency to include. If None defaults to the minimum fft
        frequency greater than zero.
    fmax : None, float
        The maximum frequency to include. If None defaults to the maximum
        fft frequency. The fft bin closest to this value is used as an
        exclusive upper bound.
    n_fft : int | None
        The length of the windows used for FFT. If None, it defaults to the
        smallest power of 2 that is greater than or equal to the signal
        length.
    width : float
        The width of the Gaussian window. If < 1, increased temporal
        resolution, if > 1, increased frequency resolution. Defaults to 1.
        (classical S-Transform).
    %(decim_tfr)s
    return_itc : bool
        Compute intertrial coherence (ITC) as well as averaged power.
    %(n_jobs)s
    %(verbose)s

    Returns
    -------
    st_power : ndarray, shape (n_channels, n_freqs, n_times_out)
        The power of the Stockwell transformed data, averaged over epochs.
        The last two dimensions are frequency and time.
    itc : ndarray, shape (n_channels, n_freqs, n_times_out) | None
        The intertrial coherence. ``None`` if ``return_itc`` is False.
    freqs : ndarray
        The frequencies.

    See Also
    --------
    mne.time_frequency.tfr_stockwell
    mne.time_frequency.tfr_multitaper
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_array_morlet

    References
    ----------
    .. footbibliography::
    """
    _validate_type(data, np.ndarray, "data")
    if data.ndim != 3:
        raise ValueError(
            "data must be 3D with shape (n_epochs, n_channels, n_times), "
            f"got {data.shape}"
        )

    decim = _ensure_slice(decim)
    # Output sizes are taken from the un-padded input, before padding below.
    n_channels = data.shape[1]
    n_out = data[..., decim].shape[-1]

    data, n_fft_, zero_pad = _check_input_st(data, n_fft)
    start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, n_fft_, sfreq)
    W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)

    out_shape = (n_channels, stop_f - start_f, n_out)
    psd = np.empty(out_shape)
    itc = np.empty(out_shape) if return_itc else None

    # One job per channel.
    parallel, my_st, _ = parallel_func(_st_power_itc, n_jobs, verbose=verbose)
    per_channel = parallel(
        my_st(data[:, ch, :], start_f, return_itc, zero_pad, decim, W)
        for ch in range(n_channels)
    )
    for ch, (ch_psd, ch_itc) in enumerate(per_channel):
        psd[ch] = ch_psd
        if ch_itc is not None:
            itc[ch] = ch_itc

    return psd, itc, freqs
@legacy(alt='.compute_tfr(method="stockwell", freqs="auto")')
@verbose
def tfr_stockwell(
    inst,
    fmin=None,
    fmax=None,
    n_fft=None,
    width=1.0,
    decim=1,
    return_itc=False,
    n_jobs=None,
    verbose=None,
):
    """Compute Time-Frequency Representation (TFR) using Stockwell Transform.

    Same computation as `~mne.time_frequency.tfr_array_stockwell`, but operates
    on `~mne.Epochs` objects instead of :class:`NumPy arrays <numpy.ndarray>`.

    See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
    for more information.

    Parameters
    ----------
    inst : Epochs | Evoked
        The epochs or evoked object.
    fmin : None, float
        The minimum frequency to include. If None defaults to the minimum fft
        frequency greater than zero.
    fmax : None, float
        The maximum frequency to include. If None defaults to the maximum
        fft frequency.
    n_fft : int | None
        The length of the windows used for FFT. If None, it defaults to the
        smallest power of 2 that is greater than or equal to the signal
        length.
    width : float
        The width of the Gaussian window. If < 1, increased temporal
        resolution, if > 1, increased frequency resolution. Defaults to 1.
        (classical S-Transform).
    decim : int
        The decimation factor on the time axis. To reduce memory usage.
    return_itc : bool
        Return intertrial coherence (ITC) as well as averaged power.
    n_jobs : int
        The number of jobs to run in parallel (over channels).
    %(verbose)s

    Returns
    -------
    power : AverageTFR
        The averaged power.
    itc : AverageTFR
        The intertrial coherence. Only returned if return_itc is True.

    See Also
    --------
    mne.time_frequency.tfr_array_stockwell
    mne.time_frequency.tfr_multitaper
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_array_morlet

    Notes
    -----
    .. versionadded:: 0.9.0

    References
    ----------
    .. footbibliography::
    """
    # verbose dec is used b/c subfunctions are verbose
    data = _get_data(inst, return_itc)

    # Restrict both the measurement info and the data to data channels.
    picks = _pick_data_channels(inst.info)
    info = pick_info(inst.info, picks)
    data = data[:, picks, :]

    decim = _ensure_slice(decim)
    array_kwargs = dict(
        sfreq=info["sfreq"],
        fmin=fmin,
        fmax=fmax,
        n_fft=n_fft,
        width=width,
        decim=decim,
        return_itc=return_itc,
        n_jobs=n_jobs,
    )
    power, itc, freqs = tfr_array_stockwell(data, **array_kwargs)

    # Same time decimation as applied to the data.
    times = inst.times[decim].copy()
    nave = len(data)

    power_tfr = AverageTFRArray(
        info=info,
        data=power,
        times=times,
        freqs=freqs,
        nave=nave,
        method="stockwell-power",
    )
    if not return_itc:
        return power_tfr

    # The ITC container gets its own copies of info, times and freqs.
    itc_tfr = AverageTFRArray(
        info=deepcopy(info),
        data=itc,
        times=times.copy(),
        freqs=freqs.copy(),
        nave=nave,
        method="stockwell-itc",
    )
    return power_tfr, itc_tfr