# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from math import ceil
import numpy as np
from scipy.fft import irfft, rfft, rfftfreq
from ..utils import logger, verbose
@verbose
def stft(x, wsize, tstep=None, verbose=None):
    """Compute the Short-Term Fourier Transform (STFT) using a sine window.

    The transform is built as a tight frame, so it can be inverted
    exactly. Only the positive frequencies are returned.

    Parameters
    ----------
    x : array, shape (n_signals, n_times)
        Real-valued multi-channel signal. A 1D array is handled as a
        single signal.
    wsize : int
        Length of the STFT window in samples (must be a multiple of 4).
    tstep : int
        Step between successive windows in samples (must be a multiple of 2,
        a divider of wsize and smaller than wsize/2) (default: wsize/2).
    %(verbose)s

    Returns
    -------
    X : array, shape (n_signals, wsize // 2 + 1, n_step)
        STFT coefficients for positive frequencies with
        ``n_step = ceil(T / tstep)``.

    See Also
    --------
    istft
    stftfreq
    """
    if not np.isrealobj(x):
        raise ValueError("x is not a real valued array")

    # Promote a single signal to a one-row 2D array.
    if x.ndim == 1:
        x = x[None, :]

    n_signals, T = x.shape
    wsize = int(wsize)

    # Validate the window length and the step size.
    if wsize % 4:
        raise ValueError("The window length must be a multiple of 4.")

    tstep = int(wsize / 2 if tstep is None else tstep)

    if (wsize % tstep) or (tstep % 2):
        raise ValueError(
            "The step size must be a multiple of 2 and a divider of the window length."
        )

    if tstep > wsize / 2:
        raise ValueError("The step size must be smaller than half the window length.")

    n_step = int(ceil(T / float(tstep)))
    n_freq = wsize // 2 + 1
    logger.info(f"Number of frequencies: {n_freq}")
    logger.info(f"Number of time steps: {n_step}")

    X = np.zeros((n_signals, n_freq, n_step), dtype=np.complex128)

    # Nothing to transform: return the empty coefficient array.
    if n_signals == 0:
        return X

    # Sine window and its square.
    window = np.sin(np.arange(0.5, wsize + 0.5) / wsize * np.pi)
    window_sq = window**2

    # Accumulate the overlapping squared windows over the padded time axis;
    # this gives the per-sample normalization applied to each frame.
    padded_len = (n_step - 1) * tstep + wsize
    starts = [step * tstep for step in range(n_step)]
    norm = np.zeros(padded_len)
    for start in starts:
        norm[start : start + wsize] += window_sq
    norm = np.sqrt(wsize * norm)

    # Zero-pad the signal so that every frame spans a full window.
    offset = (wsize - tstep) // 2
    padded = np.zeros((n_signals, padded_len), dtype=x.dtype)
    padded[:, offset : offset + T] = x

    for step, start in enumerate(starts):
        stop = start + wsize
        # Window normalized by the local overlap energy.
        scaled_window = window / norm[start:stop]
        frame = padded[:, start:stop] * scaled_window[None, :]
        # Real FFT keeps only the positive frequencies.
        X[:, :, step] = rfft(frame)

    return X
def istft(X, tstep=None, Tx=None):
    """Compute the inverse Short-Term Fourier Transform using a sine window.

    Parameters
    ----------
    X : array, shape (..., wsize / 2 + 1, n_step)
        The STFT coefficients for positive frequencies.
    tstep : int
        Step between successive windows in samples (must be a multiple of 2,
        a divider of wsize and smaller than wsize/2) (default: wsize/2).
    Tx : int
        Length of returned signal. If None Tx = n_step * tstep.

    Returns
    -------
    x : array, shape (..., Tx)
        Array containing the inverse STFT signal. The leading dimensions
        are those of ``X`` without its last two axes.

    Raises
    ------
    ValueError
        If ``X`` has fewer than two dimensions, if its number of frequency
        rows is even, or if ``tstep`` is not an even divider of ``wsize``
        that is at most ``wsize / 2``.

    See Also
    --------
    stft
    """
    # Errors and warnings
    X = np.asarray(X)
    if X.ndim < 2:
        raise ValueError(f"X must have ndim >= 2, got {X.ndim}")
    n_win, n_step = X.shape[-2:]
    signal_shape = X.shape[:-2]
    if n_win % 2 == 0:
        raise ValueError("The number of rows of the STFT matrix must be odd.")

    wsize = 2 * (n_win - 1)
    if tstep is None:
        # Integer division: tstep is used for array shapes and slice bounds,
        # which must be integers.
        tstep = wsize // 2
    # Same integer coercion as in stft, so integer-valued floats are accepted.
    tstep = int(tstep)

    if wsize % tstep:
        raise ValueError(
            "The step size must be a divider of two times the "
            "number of rows of the STFT matrix minus two."
        )

    # wsize is even by construction; the constraint applies to the step.
    if tstep % 2:
        raise ValueError("The step size must be a multiple of 2.")

    if tstep > wsize / 2:
        raise ValueError(
            "The step size must be smaller than the number of "
            "rows of the STFT matrix minus one."
        )

    if Tx is None:
        Tx = n_step * tstep

    T = n_step * tstep

    x = np.zeros(signal_shape + (T + wsize - tstep,), dtype=np.float64)

    # No signal to reconstruct: return the (empty) output right away.
    if np.prod(signal_shape) == 0:
        return x[..., :Tx]

    # Defining sine window
    win = np.sin(np.arange(0.5, wsize + 0.5) / wsize * np.pi)
    win2 = win**2

    # Pre-processing for edges: accumulate the overlapping squared windows
    # to get the per-sample normalization of the overlap-add.
    swin = np.zeros(T + wsize - tstep, dtype=np.float64)
    for t in range(n_step):
        swin[t * tstep : t * tstep + wsize] += win2
    swin = np.sqrt(swin / wsize)

    for t in range(n_step):
        # IFFT
        frame = irfft(X[..., t], wsize)
        # Overlap-add
        frame *= win / swin[t * tstep : t * tstep + wsize]
        x[..., t * tstep : t * tstep + wsize] += frame

    # Truncation: drop the leading padding, then keep the requested length.
    x = x[..., (wsize - tstep) // 2 : (wsize - tstep) // 2 + T + 1]
    x = x[..., :Tx].copy()
    return x
def stftfreq(wsize, sfreq=None):  # noqa: D401
    """Compute frequencies of stft transformation.

    Parameters
    ----------
    wsize : int
        Size of stft window.
    sfreq : float
        Sampling frequency. If None, the frequencies are normalized
        (cycles per sample, between 0 and 0.5); otherwise they are given
        in Hz (between 0 and ``sfreq / 2``).

    Returns
    -------
    freqs : array, shape (wsize // 2 + 1,)
        The positive frequencies returned by stft.

    See Also
    --------
    stft
    istft
    """
    # rfftfreq with the default unit sample spacing yields cycles per sample.
    freqs = rfftfreq(wsize)
    if sfreq is not None:
        # Scale the normalized frequencies to Hz.
        freqs *= float(sfreq)
    return freqs
def stft_norm2(X):
    """Compute the squared L2 norm of an STFT transform.

    Accounts for the fact that stft only returns the positive
    frequencies. Since a tight frame is used, this quantity is conserved
    by the stft.

    Parameters
    ----------
    X : 3D complex array
        The STFT transforms.

    Returns
    -------
    norms2 : array
        The squared L2 norm of every row of X.
    """
    power = (X * X.conj()).real
    # Count every coefficient twice, then take back one copy of the first
    # and of the last frequency rows so that those are counted only once.
    doubled_total = 2.0 * power.sum(axis=2).sum(axis=1)
    first_freq = np.sum(power[:, 0, :], axis=1)
    last_freq = np.sum(power[:, -1, :], axis=1)
    return doubled_total - first_freq - last_freq
def stft_norm1(X):
    """Compute the L1 norm of an STFT transform.

    Accounts for the fact that stft only returns the positive
    frequencies.

    Parameters
    ----------
    X : 3D complex array
        The STFT transforms.

    Returns
    -------
    norms : array
        The L1 norm of every row of X.
    """
    magnitude = np.abs(X)
    # Count every coefficient twice, then take back one copy of the first
    # and of the last frequency rows so that those are counted only once.
    doubled_total = 2.0 * magnitude.sum(axis=(1, 2))
    first_freq = np.sum(magnitude[:, 0, :], axis=1)
    last_freq = np.sum(magnitude[:, -1, :], axis=1)
    return doubled_total - first_freq - last_freq