# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from mne import (
SourceEstimate,
pick_events,
read_cov,
read_dipole,
read_events,
read_evokeds,
read_source_spaces,
)
from mne.chpi import compute_chpi_snr
from mne.datasets import testing
from mne.filter import create_filter
from mne.io import read_raw_fif
from mne.minimum_norm import read_inverse_operator
from mne.time_frequency import CrossSpectralDensity
from mne.utils import _record_warnings
from mne.viz import (
plot_bem,
plot_chpi_snr,
plot_csd,
plot_events,
plot_filter,
plot_snr_estimate,
plot_source_spectrogram,
)
from mne.viz.misc import _handle_event_colors
from mne.viz.utils import _get_color_list
# Files from the MNE testing dataset (download=False: nothing is fetched at
# import time; the tests that read these are decorated with
# ``testing.requires_testing_data``).
data_path = testing.data_path(download=False)
subjects_dir = data_path / "subjects"
_sample_meg_dir = data_path / "MEG" / "sample"
src_fname = subjects_dir / "sample" / "bem" / "sample-oct-6-src.fif"
inv_fname = _sample_meg_dir / "sample_audvis_trunc-meg-eeg-oct-4-meg-inv.fif"
evoked_fname = _sample_meg_dir / "sample_audvis-ave.fif"
dip_fname = _sample_meg_dir / "sample_audvis_trunc_set1.dip"
chpi_fif_fname = data_path / "SSS" / "test_move_anon_raw.fif"

# Small files located relative to this test module
# (``parents[2] / "io" / "tests" / "data"``).
base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
raw_fname = base_dir / "test_raw.fif"
cov_fname = base_dir / "test-cov.fif"
event_fname = base_dir / "test-eve.fif"
def _get_raw():
    """Read the test raw FIF file with its data preloaded."""
    raw = read_raw_fif(raw_fname, preload=True)
    return raw
def _get_events():
    """Read the events array stored in the test events file."""
    events = read_events(event_fname)
    return events
def test_plot_filter():
    """Test filter plotting."""
    l_freq, h_freq, sfreq = 2.0, 40.0, 1000.0
    data = np.zeros(5000)
    # Breakpoints/gains of a desired response, passed as ``freq``/``gain``.
    freq = [0, 2, 40, 50, 500]
    gain = [0, 1, 1, 0, 0]

    # FIR filter, with and without freq/gain
    h = create_filter(data, sfreq, l_freq, h_freq, fir_design="firwin2")
    plot_filter(h, sfreq)
    plt.close("all")
    plot_filter(h, sfreq, freq, gain)
    plt.close("all")

    # IIR filters: default output, then explicit "ba" and "sos" outputs
    iir = create_filter(data, sfreq, l_freq, h_freq, method="iir")
    plot_filter(iir, sfreq)
    plt.close("all")
    iir = create_filter(
        data, sfreq, l_freq, h_freq, method="iir", iir_params={"output": "ba"}
    )
    plot_filter(iir, sfreq, compensate=True)
    plt.close("all")
    iir = create_filter(
        data, sfreq, l_freq, h_freq, method="iir", iir_params={"output": "sos"}
    )
    plot_filter(iir, sfreq, compensate=True)
    plt.close("all")
    plot_filter(iir, sfreq, freq, gain)
    plt.close("all")
    iir_ba = create_filter(
        data, sfreq, l_freq, h_freq, method="iir", iir_params=dict(output="ba")
    )
    plot_filter(iir_ba, sfreq, freq, gain)
    plt.close("all")

    # Subplot selection via ``plot``: default, tuple, list, str, 1-tuple
    fig = plot_filter(h, sfreq, freq, gain, fscale="linear")
    assert len(fig.axes) == 3
    plt.close("all")
    fig = plot_filter(h, sfreq, freq, gain, fscale="linear", plot=("time", "delay"))
    assert len(fig.axes) == 2
    plt.close("all")
    fig = plot_filter(
        h, sfreq, freq, gain, fscale="linear", plot=["magnitude", "delay"]
    )
    assert len(fig.axes) == 2
    plt.close("all")
    fig = plot_filter(h, sfreq, freq, gain, fscale="linear", plot="magnitude")
    assert len(fig.axes) == 1
    plt.close("all")
    # Trailing comma matters: ("magnitude") is just the str already tested above.
    fig = plot_filter(h, sfreq, freq, gain, fscale="linear", plot=("magnitude",))
    assert len(fig.axes) == 1
    plt.close("all")
    with pytest.raises(ValueError, match="Invalid value for the .plot"):
        plot_filter(h, sfreq, freq, gain, plot="turtles")

    # User-supplied axes must match the number of requested subplots
    _, axes = plt.subplots(1)
    fig = plot_filter(h, sfreq, freq, gain, plot="magnitude", axes=axes)
    assert len(fig.axes) == 1
    _, axes = plt.subplots(2)
    fig = plot_filter(h, sfreq, freq, gain, plot=("magnitude", "delay"), axes=axes)
    assert len(fig.axes) == 2
    plt.close("all")
    _, axes = plt.subplots(1)
    with pytest.raises(ValueError, match="Length of axes"):
        plot_filter(h, sfreq, freq, gain, plot=("magnitude", "delay"), axes=axes)
    # Don't leak the last figure into later tests.
    plt.close("all")
def test_plot_cov():
    """Test covariance plotting, for both real and complex-valued data."""
    raw = _get_raw()
    cov = read_cov(cov_fname)
    # Exclude every channel from index 6 onward; a projection warning is expected.
    excluded = raw.ch_names[6:]
    with pytest.warns(RuntimeWarning, match="projection"):
        fig_a, fig_b = cov.plot(raw.info, proj=True, exclude=excluded)
    # Complex-valued covariance data must be plottable as well.
    cov["data"] = cov.data * complex(1, 1)
    fig_a, fig_b = cov.plot(raw.info)
@testing.requires_testing_data
def test_plot_bem():
    """Test BEM contour plotting and its argument validation."""
    pytest.importorskip("nibabel")
    common = dict(subject="sample", subjects_dir=subjects_dir)

    # Invalid inputs
    with pytest.raises(OSError, match="MRI file .* not found"):
        plot_bem(subject="bad-subject", subjects_dir=subjects_dir)
    with pytest.raises(ValueError, match="Invalid value for the 'orientation"):
        plot_bem(orientation="bad-ori", **common)
    with pytest.raises(ValueError, match="sorted 1D array"):
        plot_bem(slices=[0, 500], **common)

    # Two sagittal slices, each drawing the 3 BEM surfaces
    fig = plot_bem(orientation="sagittal", slices=[25, 50], **common)
    assert len(fig.axes) == 2
    assert len(fig.axes[0].collections) == 3

    # White-matter surfaces add one contour per hemisphere
    fig = plot_bem(orientation="coronal", brain_surfaces="white", **common)
    assert len(fig.axes[0].collections) == 5  # 3 BEM surfaces + 2 hemis

    # A source space adds a single extra contour
    fig = plot_bem(orientation="coronal", slices=[25, 50], src=src_fname, **common)
    assert len(fig.axes[0].collections) == 4  # 3 BEM surfaces + 1 src contour

    # The inverse's source space is rejected as head (not MRI) coordinates
    with pytest.raises(ValueError, match="MRI coordinates, got head"):
        plot_bem(src=inv_fname, **common)
def test_event_colors():
    """Test that event colors fall back to defaults and honor custom dicts."""
    events = pick_events(_get_events(), include=[1, 2])
    unique_events = set(events[:, 2])

    # No color dict: event 1 gets the first default color.
    colors = _handle_event_colors(None, unique_events, dict())
    assert colors[1] == _get_color_list()[0]

    # Colors keyed by event_id name override the defaults.
    custom_colors = dict(foo="k", bar="#facade")
    name_to_id = dict(foo=1, bar=2)
    colors = _handle_event_colors(
        color_dict=custom_colors, unique_events=unique_events, event_id=name_to_id
    )
    assert (colors[1], colors[2]) == ("k", "#facade")
def test_plot_events():
    """Test plotting events, including event_id/color mismatch handling."""
    event_labels = {"aud_l": 1, "aud_r": 2, "vis_l": 3, "vis_r": 4}
    color = {1: "green", 2: "yellow", 3: "red", 4: "c"}
    raw = _get_raw()
    events = _get_events()
    sfreq, first_samp = raw.info["sfreq"], raw.first_samp

    # Basic plotting; a legend is drawn even without an event_id
    fig = plot_events(events, sfreq, first_samp)
    assert fig.axes[0].get_legend() is not None
    plot_events(events, sfreq, first_samp, equal_spacing=False)
    # sfreq may be omitted
    plot_events(events, first_samp=first_samp)

    # event_id / color dicts that don't line up with the events warn
    with pytest.warns(RuntimeWarning, match="will be ignored"):
        fig = plot_events(events, sfreq, first_samp, event_id=event_labels)
    assert fig.axes[0].get_legend() is not None
    with pytest.warns(RuntimeWarning, match="Color was not assigned"):
        plot_events(events, sfreq, first_samp, color=color)
    with (
        _record_warnings(),
        pytest.warns(RuntimeWarning, match=r"vent \d+ missing from event_id"),
    ):
        plot_events(events, sfreq, first_samp, event_id=event_labels, color=color)
    multimatch = r"event \d+ missing from event_id|in the color dict but is"
    with _record_warnings(), pytest.warns(RuntimeWarning, match=multimatch):
        plot_events(events, sfreq, first_samp, event_id={"aud_l": 1}, color=color)

    # event_id entries absent from events, under each on_missing mode
    only_missing = {"missing": 111}
    with pytest.raises(ValueError, match="from event_id is not present in"):
        plot_events(events, sfreq, first_samp, event_id=only_missing)
    with pytest.raises(RuntimeError, match="No usable event IDs"):
        plot_events(
            events, sfreq, first_samp, event_id=only_missing, on_missing="ignore"
        )
    partly_missing = {"aud_l": 1, "missing": 111}
    with (
        _record_warnings(),
        pytest.warns(RuntimeWarning, match="from event_id is not present in"),
    ):
        plot_events(
            events, sfreq, first_samp, event_id=partly_missing, on_missing="warn"
        )
    with _record_warnings(), pytest.warns(RuntimeWarning, match="event 2 missing"):
        plot_events(
            events, sfreq, first_samp, event_id=partly_missing, on_missing="ignore"
        )

    # Keep only event 1, so every remaining event is listed in event_id
    events = events[events[:, 2] == 1]
    assert len(events) > 0
    plot_events(
        events, sfreq, first_samp, event_id=partly_missing, on_missing="ignore"
    )

    # An empty events array is rejected
    with pytest.raises(ValueError, match="No events"):
        plot_events(np.empty((0, 3)))
@testing.requires_testing_data
def test_plot_source_spectrogram():
    """Test plotting of source spectrogram."""
    # Same file as the module-level ``src_fname``; reuse it rather than
    # re-spelling the path.
    sample_src = read_source_spaces(src_fname)
    # dense version: one constant row per source-space vertex
    vertices = [s["vertno"] for s in sample_src]
    n_times = 5
    n_verts = sum(len(v) for v in vertices)
    stc_data = np.ones((n_verts, n_times))
    stc = SourceEstimate(stc_data, vertices, 1, 1)  # tmin=1, tstep=1
    freq_bins = [[1, 2], [3, 4]]
    plot_source_spectrogram([stc, stc], freq_bins)
    plt.close("all")

    # Invalid inputs: no stcs, then tmin/tmax outside the stc time range
    with pytest.raises(ValueError):
        plot_source_spectrogram([], [])
    with pytest.raises(ValueError):
        plot_source_spectrogram([stc, stc], freq_bins, tmin=0)
    with pytest.raises(ValueError):
        plot_source_spectrogram([stc, stc], freq_bins, tmax=7)
@pytest.mark.slowtest
@testing.requires_testing_data
def test_plot_snr():
    """Smoke-test SNR estimate plotting from an evoked response and an inverse."""
    inverse = read_inverse_operator(inv_fname)
    evokeds = read_evokeds(evoked_fname, baseline=(None, 0))
    plot_snr_estimate(evokeds[0], inverse)
@testing.requires_testing_data
def test_plot_dipole_amplitudes():
    """Smoke-test plotting the amplitudes of dipoles read from file."""
    read_dipole(dip_fname).plot_amplitudes(show=False)
def test_plot_csd():
    """Test plotting of CSD matrices in both supported modes."""
    # [1, 2, 3]: presumably the upper-triangle entries for the two channels.
    csd = CrossSpectralDensity(
        [1, 2, 3], ["CH1", "CH2"], frequencies=[(10, 20)], n_fft=1, tmin=0, tmax=1
    )
    # "csd" = cross-spectral density, "coh" = coherence
    for mode in ("csd", "coh"):
        plot_csd(csd, mode=mode)
@pytest.mark.slowtest  # Slow on Azure
@testing.requires_testing_data
def test_plot_chpi_snr():
    """Test plotting cHPI SNRs, with default and user-supplied axes."""
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes")
    snr = compute_chpi_snr(raw)
    hpi_freqs = snr["freqs"]

    # Default figure: one axes per result entry except two, and one line
    # plus one legend label per frequency.
    fig = plot_chpi_snr(snr)
    assert len(fig.axes) == len(snr) - 2
    assert len(fig.axes[0].lines) == len(hpi_freqs)
    assert len(fig.legends) == 1
    labels = [label.get_text() for label in fig.legends[0].get_texts()]
    assert len(labels) == len(hpi_freqs)
    # Each legend label starts with its frequency value.
    assert_array_equal([float(label.split()[0]) for label in labels], hpi_freqs)

    # User-supplied axes: exactly six are required
    _, grid = plt.subplots(2, 3)
    plot_chpi_snr(snr, axes=grid.ravel())
    _, too_few = plt.subplots(5)
    with pytest.raises(ValueError, match="a list of 6 axes, got length 5"):
        plot_chpi_snr(snr, axes=too_few.ravel())