# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import warnings
from functools import partial
import numpy as np
from scipy.signal import spectrogram
from ..parallel import parallel_func
from ..utils import _check_option, _ensure_int, logger, verbose
from ..utils.numerics import _mask_to_onsets_offsets
# adapted from SciPy
# https://github.com/scipy/scipy/blob/f71e7fad717801c4476312fe1e23f2dfbb4c9d7f/scipy/signal/_spectral_py.py#L2019 # noqa: E501
def _median_biases(n):
    """Return median-bias correction factors for 0 through ``n`` samples.

    The returned array has length ``n + 1``. Entry ``k`` equals SciPy's
    ``_median_bias(k)``, i.e. ``1 + sum(1 / (i + 1) - 1 / i)`` taken over the
    even numbers ``i = 2, 4, ..., 2 * ((k - 1) // 2)``. Entries 0, 1 and 2
    involve no terms and are therefore exactly 1.
    """
    factors = np.ones(n + 1)
    if n < 3:
        # Every entry has zero terms, so there is nothing to correct.
        return factors
    # SciPy evaluates each bias on its own; here the partial sums are built
    # once and shared. Entry k (k >= 3) uses the first (k - 1) // 2 of them,
    # e.g. k=3,4 -> 1 term (i=2), k=5,6 -> 2 terms (i=2, 4), ...
    evens = 2 * np.arange(1, (n - 1) // 2 + 1)
    running = 1 + np.cumsum(1.0 / (evens + 1) - 1.0 / evens)
    n_terms = (np.arange(3, n + 1) - 1) // 2
    factors[3:] = running[n_terms - 1]
    return factors
def _decomp_aggregate_mask(epoch, func, average, freq_sl):
    """Compute a spectrogram of ``epoch`` and optionally aggregate its segments.

    Parameters
    ----------
    epoch : ndarray
        Data passed straight to ``func``.
    func : callable
        Called as ``func(epoch)``; must return a 3-tuple whose last element is
        the spectrogram, with frequencies on axis -2 and segments on axis -1.
    average : str | None
        ``"mean"`` -> NaN-aware mean over segments; ``"median"`` -> NaN-aware
        median over segments divided by the median bias for the number of
        non-NaN segments; anything else -> segments are returned unaggregated.
    freq_sl : slice
        Selection applied to the frequency axis (-2).

    Returns
    -------
    spect : ndarray
        The sliced spectrogram, with the segment axis removed when aggregated.
    """
    _freqs, _times, full = func(epoch)
    selected = full[..., freq_sl, :]
    # Aggregating right here (per call) keeps the unaggregated data short-lived.
    if average == "median":
        correction = _median_biases(selected.shape[-1])
        n_valid = np.sum(~np.isnan(selected), axis=-1)
        return np.nanmedian(selected, axis=-1) / correction[n_valid]
    if average == "mean":
        return np.nanmean(selected, axis=-1)
    return selected
def _spect_func(epoch, func, freq_sl, average, *, output="power"):
    """Run ``_decomp_aggregate_mask`` on ``epoch``, slice by slice if it is large.

    Arrays larger than 10e6 bytes are processed one 1-D slice along the last
    axis at a time (via ``np.apply_along_axis``), so the unaggregated
    spectrogram only ever exists for one slice; this limits peak memory.
    Smaller arrays are handled in a single call, which avoids the Python
    overhead of many calls. ``output`` is accepted for call-site compatibility
    but is not used here.
    """
    # Splitting trades speed for memory. A custom spectrogram that averages
    # after every transform could avoid both costs, but would need many
    # Python-level calls of its own.
    agg_kwargs = {"func": func, "average": average, "freq_sl": freq_sl}
    if epoch.nbytes <= 10e6:
        return _decomp_aggregate_mask(epoch, **agg_kwargs)
    return np.apply_along_axis(_decomp_aggregate_mask, -1, epoch, **agg_kwargs)
def _check_nfft(n, n_fft, n_per_seg, n_overlap):
    """Validate and reconcile the Welch segmentation parameters.

    Parameters
    ----------
    n : int
        Number of time samples in the signal.
    n_fft : int
        FFT length.
    n_per_seg : int | None
        Requested segment length; None means "use ``n_fft``".
    n_overlap : int
        Number of overlapping samples between consecutive segments.

    Returns
    -------
    n_fft, n_per_seg, n_overlap : int
        ``n_fft`` and ``n_overlap`` unchanged; ``n_per_seg`` clipped so that it
        exceeds neither ``n_fft`` nor ``n``.

    Raises
    ------
    ValueError
        If ``n_per_seg`` is None while ``n_fft > n``, or if ``n_overlap`` is
        not smaller than the resulting ``n_per_seg``.
    """
    if n_per_seg is None:
        if n_fft > n:
            raise ValueError(
                "If n_per_seg is None n_fft is not allowed to be > "
                "n_times. If you want zero-padding, you have to set "
                f"n_per_seg to relevant length. Got n_fft of {n_fft} while"
                f" signal length is {n}."
            )
        seg = n_fft
    elif n_per_seg > n_fft:
        seg = n_fft
    else:
        seg = n_per_seg
    # A segment can never be longer than the signal itself.
    seg = min(seg, n)
    if n_overlap >= seg:
        raise ValueError(
            "n_overlap cannot be greater than n_per_seg (or n_fft). Got n_overlap "
            f"of {n_overlap} while n_per_seg is {seg}."
        )
    return n_fft, seg, n_overlap
@verbose
def psd_array_welch(
    x,
    sfreq,
    fmin=0,
    fmax=np.inf,
    n_fft=256,
    n_overlap=0,
    n_per_seg=None,
    n_jobs=None,
    average="mean",
    window="hamming",
    remove_dc=True,
    *,
    output="power",
    verbose=None,
):
    """Compute power spectral density (PSD) using Welch's method.

    Welch's method is described in :footcite:t:`Welch1967`.

    Parameters
    ----------
    x : array, shape=(..., n_times)
        The data to compute PSD from.
    sfreq : float
        The sampling frequency.
    fmin : float
        The lower frequency of interest.
    fmax : float
        The upper frequency of interest.
    n_fft : int
        The length of FFT used, must be ``>= n_per_seg`` (default: 256).
        The segments will be zero-padded if ``n_fft > n_per_seg``.
    n_overlap : int
        The number of points of overlap between segments. Must be smaller
        than ``n_per_seg`` (after ``n_per_seg`` has been clipped to ``n_fft``
        and to the signal length), otherwise a ``ValueError`` is raised.
        The default value is 0.
    n_per_seg : int | None
        Length of each Welch segment (windowed with ``window``). Defaults
        to None, which sets n_per_seg equal to n_fft. Values larger than
        ``n_fft`` or than the signal length are clipped to those lengths.
    %(n_jobs)s
    %(average_psd)s

        .. versionadded:: 0.19.0
    %(window_psd)s

        .. versionadded:: 0.22.0
    %(remove_dc)s
    output : str
        The format of the returned ``psds`` array, ``'complex'`` or
        ``'power'``:

        * ``'power'`` : the power spectral density is returned.
        * ``'complex'`` : the complex Fourier coefficients are returned per
          window.

        .. versionadded:: 1.4.0
    %(verbose)s

    Returns
    -------
    psds : ndarray, shape (..., n_freqs) or (..., n_freqs, n_segments)
        The power spectral densities. If ``average='mean'`` or
        ``average='median'``, the returned array will have the same shape
        as the input data plus an additional frequency dimension.
        If ``average=None``, the returned array will have the same shape as
        the input data plus two additional dimensions corresponding to
        frequencies and the unaggregated segments, respectively.
    freqs : ndarray, shape (n_freqs,)
        The frequencies.

    Notes
    -----
    If ``x`` contains NaNs (e.g., from ``"bad_*"`` annotations), the NaN
    pattern must be identical across all signals, and the spectrum is
    computed separately for each contiguous span of good samples. With
    ``average='mean'`` or ``average='median'`` the per-span results are
    combined by a weighted average (weights are the number of samples used
    from each span); with ``average=None`` the segments of all spans are
    concatenated in time order.

    .. versionadded:: 0.14.0

    References
    ----------
    .. footbibliography::
    """
    _check_option("average", average, (None, False, "mean", "median"))
    _check_option("output", output, ("power", "complex"))
    detrend = "constant" if remove_dc else False
    mode = "complex" if output == "complex" else "psd"
    n_fft = _ensure_int(n_fft, "n_fft")
    n_overlap = _ensure_int(n_overlap, "n_overlap")
    if n_per_seg is not None:
        n_per_seg = _ensure_int(n_per_seg, "n_per_seg")
    if average is False:
        average = None

    dshape = x.shape[:-1]
    n_times = x.shape[-1]
    x = x.reshape(-1, n_times)

    # Prep the PSD
    n_fft, n_per_seg, n_overlap = _check_nfft(n_times, n_fft, n_per_seg, n_overlap)
    win_size = n_fft / float(sfreq)
    logger.info(f"Effective window size : {win_size:0.3f} (s)")
    freqs = np.arange(n_fft // 2 + 1, dtype=float) * (sfreq / n_fft)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        raise ValueError(f"No frequencies found between fmin={fmin} and fmax={fmax}")
    # Contiguous slice from the first to the last frequency inside [fmin, fmax].
    freq_sl = slice(*(np.where(freq_mask)[0][[0, -1]] + [0, 1]))
    del freq_mask
    freqs = freqs[freq_sl]

    # Parallelize across first N-1 dimensions
    logger.debug(
        f"Spectogram using {n_fft}-point FFT on {n_per_seg} samples with "
        f"{n_overlap} overlap and {window} window"
    )
    parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs)
    _func = partial(
        spectrogram,
        detrend=detrend,
        noverlap=n_overlap,
        nperseg=n_per_seg,
        nfft=n_fft,
        fs=sfreq,
        window=window,
        mode=mode,
    )
    if np.any(np.isnan(x)):
        good_mask = ~np.isnan(x)
        # NaNs originate from annot, so must match for all channels. Note that we CANNOT
        # use np.testing.assert_allclose() here; it is strict about shapes/broadcasting
        assert np.allclose(good_mask, good_mask[[0]], equal_nan=True)
        t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0])
        x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
        # weights reflect the number of samples used from each span. For spans longer
        # than `n_per_seg`, trailing samples may be discarded. For spans shorter than
        # `n_per_seg`, the wrapped function (`scipy.signal.spectrogram`) automatically
        # reduces `n_per_seg` to match the span length (with a warning).
        step = n_per_seg - n_overlap
        span_lengths = [span.shape[-1] for span in x_splits]
        weights = [
            w if w < n_per_seg else w - ((w - n_overlap) % step) for w in span_lengths
        ]
        if average is None:
            # Segments are not aggregated, so averaging across spans would mix
            # unrelated segments (and fails outright when spans yield different
            # numbers of segments). Keep every segment, in time order, instead.
            agg_func = partial(np.concatenate, axis=-1)
        else:
            agg_func = partial(np.average, axis=0, weights=weights)
        if n_jobs > 1:
            logger.info(
                f"Data split into {len(x_splits)} (probably unequal) chunks due to "
                '"bad_*" annotations. Parallelization may be sub-optimal.'
            )
        if (np.array(span_lengths) < n_per_seg).any():
            logger.info(
                "At least one good data span is shorter than n_per_seg, and will be "
                "analyzed with a shorter window than the rest of the file."
            )

        def func(span, *args, **kwargs):
            # A span shorter than n_per_seg fits exactly one (shortened) window,
            # so the overlap has no effect there. SciPy nevertheless raises if
            # noverlap >= the shortened window length, so drop the overlap for
            # such spans (results are unchanged for spans that worked before).
            if span.shape[-1] < n_per_seg:
                kwargs["noverlap"] = 0
            # swallow SciPy warnings caused by short good data spans
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    action="ignore",
                    module="scipy",
                    category=UserWarning,
                    message=r"nperseg = \d+ is greater than input length",
                )
                return _func(span, *args, **kwargs)

    else:
        x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
        agg_func = partial(np.concatenate, axis=0)
        func = _func
    f_spect = parallel(
        my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output)
        for d in x_splits
    )
    psds = agg_func(f_spect)
    shape = dshape + (len(freqs),)
    if average is None:
        shape = shape + (-1,)
    # reshape() instead of assigning to .shape, which NumPy discourages.
    psds = psds.reshape(shape)
    return psds, freqs