# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import pickle
from itertools import product
from os import path as op
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal
from pytest import raises
import mne
from mne.channels import equalize_channels
from mne.proj import Projection
from mne.time_frequency import (
CrossSpectralDensity,
csd_array_fourier,
csd_array_morlet,
csd_array_multitaper,
csd_fourier,
csd_morlet,
csd_multitaper,
csd_tfr,
pick_channels_csd,
read_csd,
tfr_morlet,
)
from mne.time_frequency.csd import _sym_mat_to_vector, _vector_to_sym_mat
from mne.utils import sum_squared
# Fixture files live under ``<this dir>/../../io/tests/data``.
_DATA_DIR_PARTS = ("..", "..", "io", "tests", "data")
base_dir = op.join(op.dirname(__file__), *_DATA_DIR_PARTS)
raw_fname, event_fname = (
    op.join(base_dir, basename) for basename in ("test_raw.fif", "test-eve.fif")
)
def _make_csd(add_proj=False):
    """Build a small CrossSpectralDensity fixture (3 channels, 4 frequencies).

    Parameters
    ----------
    add_proj : bool
        If True, attach one Projection whose data is a single row of ones
        spanning the three channels. Otherwise ``projs`` is None.

    Returns
    -------
    csd : CrossSpectralDensity
        Object holding a ``(6, 4)`` data array with the values 0..23, where
        column ``k`` contains ``6 * k`` .. ``6 * k + 5``. It is defined for
        the frequencies 1-4 Hz, ``n_fft=1`` and the time window 0.0-1.0 s.
    """
    ch_names = ["CH1", "CH2", "CH3"]
    freqs = [1.0, 2.0, 3.0, 4.0]
    # Fill row-wise per frequency, then transpose so frequencies are columns.
    values = np.arange(6.0 * len(freqs)).reshape(len(freqs), 6).T

    projs = None
    if add_proj:
        proj_spec = dict(
            col_names=ch_names,
            row_names=None,
            data=np.ones((1, len(ch_names))),
        )
        projs = [Projection(data=proj_spec)]

    return CrossSpectralDensity(values, ch_names, freqs, 1, 0.0, 1.0, projs=projs)
def test_csd():
    """Test constructing a CrossSpectralDensity."""
    # Keyword arguments shared by every constructor call in this test.
    common = dict(frequencies=1, n_fft=1, tmin=0, tmax=1)

    csd = CrossSpectralDensity([1, 2, 3], ["CH1", "CH2"], **common)
    assert_array_equal(csd._data, [[1], [2], [3]])  # Conversion to 2D array
    assert_array_equal(csd.frequencies, [1])  # Conversion to 1D array

    # (data, ch_names) combinations that the constructor must reject.
    rejected = [
        # Channels don't match
        ([1, 2, 3], ["CH1", "CH2", "Too many!"]),
        ([1, 2, 3], ["too little"]),
        # Frequencies don't match
        ([[1, 2], [3, 4], [5, 6]], ["CH1", "CH2"]),
        # Invalid dims
        ([[[1]]], ["CH1"]),
    ]
    for bad_data, bad_names in rejected:
        with pytest.raises(ValueError):
            CrossSpectralDensity(bad_data, bad_names, **common)
def test_csd_repr():
    """Test string representation of CrossSpectralDensity."""
    csd = _make_csd()

    # Pieces shared by the expected representations below.
    head = "<CrossSpectralDensity | n_channels=3, "
    timed = head + "time=0.0 to 1.0 s, "

    # Individual frequencies are listed one by one.
    assert str(csd) == timed + "frequencies=1.0, 2.0, 3.0, 4.0 Hz.>"

    # A full average is reported as a single range.
    assert str(csd.mean()) == timed + "frequencies=1.0-4.0 Hz.>"

    # Each bin is reported as its own range ...
    binned = csd.mean(fmin=[1, 3], fmax=[2, 4])
    assert str(binned) == timed + "frequencies=1.0-2.0, 3.0-4.0 Hz.>"

    # ... unless the bin holds a single frequency.
    binned = csd.mean(fmin=[1, 2], fmax=[1, 4])
    assert str(binned) == timed + "frequencies=1.0, 2.0-4.0 Hz.>"

    # Without time limits the time field reads "unknown".
    untimed = csd.copy()
    untimed.tmin = None
    untimed.tmax = None
    assert str(untimed) == (
        head + "time=unknown, frequencies=1.0, 2.0, 3.0, 4.0 Hz.>"
    )
def test_csd_mean():
    """Test averaging frequency bins of CrossSpectralDensity."""
    csd = _make_csd()

    # Every one of these selections spans all four frequencies, so they must
    # all produce the same average.
    overall = [[9], [10], [11], [12], [13], [14]]
    assert_array_equal(csd.mean()._data, overall)
    for lower, upper in [(None, 4), (1, None), (0, None), (1, 4)]:
        assert_array_equal(csd.mean(fmin=lower, fmax=upper)._data, overall)

    # Averaging within frequency bins
    two_bins = csd.mean(fmin=[1, 3], fmax=[2, 4])
    assert_array_equal(
        two_bins._data,
        [[3, 15], [4, 16], [5, 17], [6, 18], [7, 19], [8, 20]],
    )
    two_bins = csd.mean(fmin=[1, 3], fmax=[1, 4])
    assert_array_equal(
        two_bins._data,
        [[0, 15], [1, 16], [2, 17], [3, 18], [4, 19], [5, 20]],
    )

    # This flag should be set after averaging
    assert csd.mean()._is_sum

    # The .frequencies attribute groups the original frequencies per bin
    assert csd.mean().frequencies == [[1, 2, 3, 4]]
    assert csd.mean(fmin=[1, 3], fmax=[2, 4]).frequencies == [[1, 2], [3, 4]]

    # Invalid fmin/fmax combinations
    for bad_limits in (
        dict(fmin=1, fmax=[2, 3]),
        dict(fmin=[1, 2], fmax=[3]),
        dict(fmin=[1, 2], fmax=[1, 1]),
    ):
        with pytest.raises(ValueError):
            csd.mean(**bad_limits)

    # Taking the mean twice should raise an error
    averaged = csd.mean()
    with pytest.raises(RuntimeError):
        averaged.mean()
def test_csd_get_frequency_index():
    """Test the _get_frequency_index method of CrossSpectralDensity."""
    csd = _make_csd()

    # (requested frequency, expected index); the fixture holds 1, 2, 3, 4 Hz.
    lookups = [(1, 0), (2, 1), (4, 3), (0.9, 0), (2.1, 1), (4.1, 3)]
    for requested, expected in lookups:
        assert csd._get_frequency_index(requested) == expected

    # Frequency can be off by a maximum of 1
    too_far = csd.frequencies[-1] + 1.0001
    with pytest.raises(IndexError):
        csd._get_frequency_index(too_far)
def test_csd_pick_frequency():
    """Test the pick_frequency method of CrossSpectralDensity."""
    csd = _make_csd()
    matrix_2hz = [[6, 7, 8], [7, 9, 10], [8, 10, 11]]

    # Selecting by frequency and by index must yield the same 2 Hz CSD.
    for selector in (dict(freq=2), dict(index=1)):
        picked = csd.pick_frequency(**selector)
        assert picked.frequencies == [2]
        assert_array_equal(picked.get_data(), matrix_2hz)

    # Nonexistent frequency
    with pytest.raises(IndexError):
        csd.pick_frequency(-1)
    # Nonexistent index
    with pytest.raises(IndexError):
        csd.pick_frequency(index=10)

    # Invalid parameters: neither or both selectors given
    with pytest.raises(ValueError):
        csd.pick_frequency()
    with pytest.raises(ValueError):
        csd.pick_frequency(freq=2, index=1)
def test_csd_get_data():
    """Test the get_data method of CrossSpectralDensity."""
    csd = _make_csd()

    # CSD matrix corresponding to 2 Hz.
    at_2hz = csd.get_data(frequency=2)
    assert_array_equal(at_2hz, [[6, 7, 8], [7, 9, 10], [8, 10, 11]])

    # Mean CSD matrix
    mean_matrix = csd.mean().get_data()
    assert_array_equal(mean_matrix, [[9, 10, 11], [10, 12, 13], [11, 13, 14]])

    # Average across frequency bins, select bin
    second_bin = csd.mean(fmin=[1, 3], fmax=[2, 4]).get_data(index=1)
    assert_array_equal(second_bin, [[15, 16, 17], [16, 18, 19], [17, 19, 20]])

    # Invalid inputs on the unaveraged object
    with pytest.raises(ValueError):
        csd.get_data()
    with pytest.raises(ValueError):
        csd.get_data(frequency=1, index=1)
    with pytest.raises(IndexError):
        csd.get_data(frequency=15)

    # Invalid inputs on an averaged object
    averaged = csd.mean()
    with pytest.raises(ValueError):
        averaged.get_data(frequency=1)
    averaged = csd.mean()
    with pytest.raises(IndexError):
        averaged.get_data(index=15)
def test_csd_save(tmp_path):
    """Test saving and loading a CrossSpectralDensity."""
    pytest.importorskip("h5io")
    original = _make_csd(add_proj=True)
    target = op.join(str(tmp_path), "csd.h5")

    # Round-trip through disk.
    original.save(target)
    restored = read_csd(target)

    assert_array_equal(original._data, restored._data)
    assert_array_equal(original.frequencies, restored.frequencies)
    for attr in ("tmin", "tmax", "ch_names", "_is_sum"):
        assert getattr(original, attr) == getattr(restored, attr)
    # Projections must be restored as Projection instances.
    assert isinstance(restored.projs[0], Projection)
def test_csd_pickle(tmp_path):
    """Test pickling and unpickling a CrossSpectralDensity."""
    original = _make_csd()
    target = op.join(str(tmp_path), "csd.dat")

    with open(target, "wb") as fid:
        pickle.dump(original, fid)
    # Only data written a few lines above is loaded back here.
    with open(target, "rb") as fid:
        restored = pickle.load(fid)  # nosec B301

    assert_array_equal(original._data, restored._data)
    for attr in ("tmin", "tmax", "ch_names", "frequencies", "_is_sum"):
        assert getattr(original, attr) == getattr(restored, attr)
def test_pick_channels_csd():
    """Test selecting channels from a CrossSpectralDensity."""
    wanted = ["CH1", "CH3"]
    picked = pick_channels_csd(_make_csd(), wanted)
    assert picked.ch_names == ["CH1", "CH3"]
    # Three rows remain for the two kept channels, one column per frequency.
    expected = [[0, 6, 12, 18], [2, 8, 14, 20], [5, 11, 17, 23]]
    assert_array_equal(picked._data, expected)
def test_sym_mat_to_vector():
    """Test converting between vectors and symmetric matrices."""
    sym = [[0, 1, 2, 3], [1, 4, 5, 6], [2, 5, 7, 8], [3, 6, 8, 9]]
    mat = np.array(sym)
    vec = np.arange(10)

    # The two helpers are each other's inverse on this example.
    assert_array_equal(_sym_mat_to_vector(mat), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    assert_array_equal(_vector_to_sym_mat(vec), sym)

    # Complex values: the lower triangle is the complex conjugate of the
    # upper triangle, and the diagonal comes out real.
    complex_vec = np.arange(3) + 1j
    assert_array_equal(
        _vector_to_sym_mat(complex_vec),
        [[0.0 + 0.0j, 1.0 + 1.0j], [1.0 - 1.0j, 2.0 + 0.0j]],
    )

    # Test preservation of data type
    for dtype in (np.int8, np.float16):
        assert _sym_mat_to_vector(mat.astype(dtype)).dtype == dtype
        assert _vector_to_sym_mat(vec.astype(dtype)).dtype == dtype
def _generate_coherence_data():
    """Create an epochs object with coherence at 22Hz between channels 1 and 3.

    All three channels carry a 10 Hz sine wave, each with a different phase,
    so there is no actual coherence at 10 Hz. Channels 1 and 3 additionally
    carry a 22 Hz sine wave with identical phase, which makes them coherent
    at that frequency.

    Returns
    -------
    epochs : mne.EpochsArray
        A single 10 second epoch of three "eeg" channels sampled at 50 Hz,
        constructed with ``baseline=(0, times[-1])``.
    """
    sfreq = 50.0
    info = mne.create_info(["CH1", "CH2", "CH3"], sfreq, "eeg")
    n_channels = info["nchan"]

    n_samples = int(10 * sfreq)  # 10 seconds of data
    times = np.arange(n_samples) * (1.0 / sfreq)

    # 10 Hz sine waves, one row per channel, each with its own phase offset
    phases = np.arange(n_channels) * 0.3 * np.pi
    base_waves = np.sin(times[np.newaxis, :] * 2 * np.pi * 10 + phases[:, np.newaxis])
    data = np.zeros((1, n_channels, n_samples))
    data[0] = base_waves

    # In-phase 22 Hz sine wave added to the first and last electrodes only
    shared_wave = np.sin(times * 2 * np.pi * 22)
    data[0, [0, -1], :] += shared_wave

    events = np.array([[0, 1, 1]])  # one event
    return mne.EpochsArray(data, info, events, baseline=(0, times[-1]))
def _test_csd_matrix(csd):
    """Perform a suite of tests on a CSD matrix.

    Parameters
    ----------
    csd : CrossSpectralDensity
        CSD with channels CH1-CH3 and exactly three frequency entries. The
        callers in this file pass them in the order 10, 15 and 22 Hz, which
        is what the index-based lookups below rely on.
    """
    # Check shape of the CSD matrix
    assert len(csd.ch_names) == 3
    assert csd.ch_names == ["CH1", "CH2", "CH3"]
    assert len(csd.frequencies) == 3
    assert csd._data.shape == (6, 3)  # Only upper triangle of CSD matrix

    # Extract CSD ndarrays. Diagonals are PSDs.
    csd_10, csd_22 = csd.get_data(index=0), csd.get_data(index=2)
    power_10, power_22 = np.diag(csd_10), np.diag(csd_22)

    # Check if the CSD matrices are hermitian
    for matrix in (csd_10, csd_22):
        assert np.all(np.tril(matrix).T.conj() == np.triu(matrix))

    # Off-diagonals show phase difference
    for row, col in ((0, 1), (0, 2), (1, 2)):
        assert np.abs(csd_10[row, col].imag) > 0.4

    # No phase differences at 22 Hz
    assert np.all(np.abs(csd_22[0, 2].imag) < 1e-3)

    # The pair of channels sharing the 22 Hz signal must be more strongly
    # coupled at 22 Hz than any pair involving the 10 Hz-only channel.
    coupled = np.abs(csd_22[0, 2])
    assert coupled > np.abs(csd_22[0, 1])
    assert coupled > np.abs(csd_22[1, 2])

    # Check that electrodes/frequency combinations with signal have more
    # power than frequencies without signal.
    power_15 = np.diag(csd.get_data(index=1))
    assert np.all(power_10 > power_15)
    assert np.all(power_22[[0, -1]] > power_15[[0, -1]])
def _test_fourier_multitaper_parameters(epochs, csd_epochs, csd_array):
    """Parameter tests for csd_*_fourier and csd_*_multitaper.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs whose ``_data``, ``info["sfreq"]`` and ``tmin`` are forwarded
        to the array-based function.
    csd_epochs : callable
        Epochs-based CSD function, called as ``csd_epochs(epochs, **kwargs)``.
    csd_array : callable
        Array-based CSD function, called as
        ``csd_array(data, sfreq, t0, **kwargs)``.

    Notes
    -----
    Every call made here is expected to raise ``ValueError``.
    """
    # Frequency/time limits that both API flavours must reject.
    # NOTE(review): the reason given per entry is inferred from the values,
    # not from the implementation under test -- confirm if it matters.
    invalid_limits = [
        dict(fmin=20, fmax=10),  # fmax below fmin
        dict(fmin=20.11, fmax=20.19),  # presumably no frequency bin in range
        dict(tmin=0.15, tmax=0.1),  # tmax below tmin
        dict(tmin=-1, tmax=10),  # presumably outside the 0-9.98 s data range
        dict(tmin=10, tmax=11),  # presumably outside the 0-9.98 s data range
    ]
    for limits in invalid_limits:
        with pytest.raises(ValueError):
            csd_epochs(epochs, **limits)
        with pytest.raises(ValueError):
            csd_array(epochs._data, epochs.info["sfreq"], epochs.tmin, **limits)

    # Test checks for data types and sizes. A seeded local generator keeps
    # the inputs reproducible and leaves NumPy's global random state alone.
    rng = np.random.default_rng(0)
    mixed_types = [rng.standard_normal((3, 5)), "error"]
    mismatched_shapes = [rng.standard_normal((3, 5)), rng.standard_normal((2, 4))]
    one_dimensional = rng.standard_normal(3)
    for bad_input in (mismatched_shapes, mixed_types, one_dimensional):
        with pytest.raises(ValueError):
            csd_array(bad_input, sfreq=1)
def test_csd_fourier():
    """Test computing cross-spectral density using short-term Fourier."""
    epochs = _generate_coherence_data()
    sfreq = epochs.info["sfreq"]
    _test_fourier_multitaper_parameters(epochs, csd_fourier, csd_array_fourier)

    # Compute CSDs using every combination of time window and API flavour
    windows = [(None, None), (1, 9)]
    for (tmin, tmax), as_array in product(windows, [False, True]):
        if not as_array:
            csd = csd_fourier(epochs, fmin=9, fmax=23, tmin=tmin, tmax=tmax)
        else:
            csd = csd_array_fourier(
                epochs.get_data(copy=False),
                sfreq,
                epochs.tmin,
                fmin=9,
                fmax=23,
                tmin=tmin,
                tmax=tmax,
                ch_names=epochs.ch_names,
            )

        # Without explicit limits the full epoch range is reported.
        if tmin is None and tmax is None:
            assert csd.tmin == 0 and csd.tmax == 9.98
        else:
            assert csd.tmin == tmin and csd.tmax == tmax

        # Narrow bins around 10, 15 and 22 Hz
        csd = csd.mean([9.9, 14.9, 21.9], [10.1, 15.1, 22.1])
        _test_csd_matrix(csd)

    # For the next test, generate a simple sine wave with a known power
    sine_times = np.arange(20 * sfreq) / sfreq  # 20 seconds of signal
    sine = np.sin(2 * np.pi * 10 * sine_times)[None, None, :]  # 10 Hz wave
    expected_power = sum_squared(sine) / len(sine_times)

    # Power per sample should not depend on time window length ...
    for window_end in [12, 18]:
        n_samples = sum(sine_times <= window_end)

        # ... nor on the number of FFT points
        for extra_points in [0, 30]:
            n_fft = n_samples + extra_points
            summed = csd_array_fourier(sine, sfreq, tmax=window_end, n_fft=n_fft).sum()
            matrix = summed.get_data()
            measured_power = np.abs(matrix[0, 0]) * sfreq / n_fft
            assert abs(expected_power - measured_power) < 0.001
def test_csd_multitaper():
    """Test computing cross-spectral density using multitapers."""
    epochs = _generate_coherence_data()
    sfreq = epochs.info["sfreq"]
    _test_fourier_multitaper_parameters(epochs, csd_multitaper, csd_array_multitaper)

    # Compute CSDs using every combination of time window, API flavour and
    # adaptive weighting
    windows = [(None, None), (1, 9)]
    flags = [False, True]
    for (tmin, tmax), as_array, adaptive in product(windows, flags, flags):
        if not as_array:
            csd = csd_multitaper(
                epochs, adaptive=adaptive, fmin=9, fmax=23, tmin=tmin, tmax=tmax
            )
        else:
            csd = csd_array_multitaper(
                epochs.get_data(copy=False),
                sfreq,
                epochs.tmin,
                adaptive=adaptive,
                fmin=9,
                fmax=23,
                tmin=tmin,
                tmax=tmax,
                ch_names=epochs.ch_names,
            )

        # Without explicit limits the full epoch range is reported.
        if tmin is None and tmax is None:
            assert csd.tmin == 0 and csd.tmax == 9.98
        else:
            assert csd.tmin == tmin and csd.tmax == tmax

        # Narrow bins around 10, 15 and 22 Hz
        csd = csd.mean([9.9, 14.9, 21.9], [10.1, 15.1, 22.1])
        _test_csd_matrix(csd)

    # Test equivalence with PSD
    spectrum = epochs.compute_psd(fmin=1e-3, normalization="full")  # omit DC
    psd, psd_freqs = spectrum.get_data(return_freqs=True)
    full_csd = csd_multitaper(epochs)
    assert_allclose(psd_freqs, full_csd.frequencies)
    # Diagonal of each per-frequency CSD matrix, arranged channels x freqs
    csd_power = np.array(
        [np.diag(full_csd.get_data(index=idx)) for idx in range(len(full_csd))]
    ).T
    assert_allclose(psd[0], csd_power)

    # For the next test, generate a simple sine wave with a known power
    sine_times = np.arange(20 * sfreq) / sfreq  # 20 seconds of signal
    sine = np.sin(2 * np.pi * 10 * sine_times)[None, None, :]  # 10 Hz wave
    expected_power = sum_squared(sine) / len(sine_times)
    n_fft = len(sine_times)

    # Power per sample should not depend on time window length ...
    for window_end in [12, 18]:
        n_samples = sum(sine_times <= window_end)

        # ... nor on the number of tapers
        for n_tapers in [1, 2, 5]:
            bandwidth = sfreq / float(n_samples) * (n_tapers + 1)
            summed = csd_array_multitaper(
                sine, sfreq, tmax=window_end, bandwidth=bandwidth, n_fft=n_fft
            ).sum()
            matrix = summed.get_data()
            measured_power = np.abs(matrix[0, 0]) * sfreq / n_fft
            assert abs(expected_power - measured_power) < 0.001
def test_csd_morlet():
    """Test computing cross-spectral density using Morlet wavelets."""
    epochs = _generate_coherence_data()
    sfreq = epochs.info["sfreq"]

    # Compute CSDs by a variety of methods
    freqs = [10, 15, 22]
    n_cycles = [20, 30, 44]
    windows = [(None, None), (1, 9)]
    for (tmin, tmax), as_array in product(windows, [False, True]):
        if as_array:
            csd = csd_array_morlet(
                epochs.get_data(copy=False),
                sfreq,
                freqs,
                t0=epochs.tmin,
                n_cycles=n_cycles,
                tmin=tmin,
                tmax=tmax,
                ch_names=epochs.ch_names,
            )
        else:
            csd = csd_morlet(
                epochs, frequencies=freqs, n_cycles=n_cycles, tmin=tmin, tmax=tmax
            )

        # Without explicit limits the full epoch range is reported.
        if tmin is None and tmax is None:
            assert csd.tmin == 0 and csd.tmax == 9.98
        else:
            assert csd.tmin == tmin and csd.tmax == tmax
        _test_csd_matrix(csd)

    # CSD diagonals should contain PSD. Rows 0, 3 and 5 of the 6-row data
    # array are compared against the time-averaged TFR power of the channels.
    tfr = tfr_morlet(epochs, freqs, n_cycles, return_itc=False)
    power = np.mean(tfr.data, 2)
    csd = csd_morlet(epochs, frequencies=freqs, n_cycles=n_cycles)
    assert_allclose(csd._data[[0, 3, 5]] * sfreq, power)

    # Test using plain convolution instead of FFT
    csd = csd_morlet(epochs, frequencies=freqs, n_cycles=n_cycles, use_fft=False)
    assert_allclose(csd._data[[0, 3, 5]] * sfreq, power)

    # Test baselining warning
    epochs_nobase = epochs.copy()
    epochs_nobase.baseline = None
    with epochs_nobase.info._unlock():
        epochs_nobase.info["highpass"] = 0
    # Only the warning matters here, so the return value is not kept.
    with pytest.warns(RuntimeWarning, match="baseline"):
        csd_morlet(epochs_nobase, frequencies=[10], decim=20)
def test_equalize_channels():
    """Test equalization of channels for instances of CrossSpectralDensity."""
    full = _make_csd()
    # TODO replace with `.pick()` when CSD objects get that method
    reordered_subset = full.copy().pick_channels(["CH2", "CH1"], ordered=True)

    equalized = equalize_channels([full, reordered_subset])
    first, second = equalized
    # Both outputs end up with the same channels in the same order.
    assert first.ch_names == ["CH1", "CH2"]
    assert second.ch_names == ["CH1", "CH2"]
def test_csd_tfr():
    """Test computing cross-spectral density on time-frequency epochs."""
    rng = np.random.default_rng(11)
    n_epochs = 6
    freqs = np.arange(38, 40)
    window = dict(tmin=0.25, tmax=0.75)

    # Random data shaped after the EEG channels of the test recording
    info = mne.io.read_info(raw_fname)
    info = mne.pick_info(info, mne.pick_types(info, eeg=True))
    times = np.linspace(0, 1, int(round(info["sfreq"])))
    shape = (n_epochs, len(info.ch_names), times.size)
    epochs = mne.EpochsArray(rng.normal(size=shape) * 1e-6, info)

    # Reference: CSD computed directly from the epochs
    reference = csd_morlet(epochs, freqs, n_cycles=7, **window)

    # Under test: CSD computed from the complex single-trial TFR
    epochs_tfr = tfr_morlet(
        epochs, freqs, n_cycles=7, average=False, return_itc=False, output="complex"
    )
    from_tfr = csd_tfr(epochs_tfr, **window)

    assert_allclose(from_tfr._data, reference._data)
    assert_array_equal(from_tfr.frequencies, freqs)