# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import platform
import matplotlib.pyplot as plt
import numpy as np
import pytest
from mne import Epochs, EpochsArray, create_info
from mne.datasets import testing
from mne.event import make_fixed_length_events
from mne.utils import _record_warnings
from mne.viz import plot_drop_log
def test_plot_epochs_not_preloaded(epochs_unloaded, browser_backend):
    """Plotting lazily loaded epochs must not pull their data into memory."""
    on_arm64 = platform.machine() == "arm64"
    if on_arm64:
        pytest.xfail("Flakey verbose behavior on macOS arm64")
    # No data array before plotting ...
    assert epochs_unloaded._data is None
    epochs_unloaded.plot()
    # ... and still none once the browser has been drawn.
    assert epochs_unloaded._data is None
def test_plot_epochs_basic(epochs, epochs_full, noise_cov_io, capsys, browser_backend):
    """Test epoch plotting.

    Covers plotting with a title, whitening with a noise covariance under
    several bad-channel configurations (including a covariance that lacks
    some data channels), channel picking, error handling for empty picks and
    non-string titles, and grouping channels by selection.
    """
    assert len(epochs.events) == 1
    with epochs.info._unlock():
        epochs.info["lowpass"] = 10.0  # allow heavy decim during plotting
    fig = epochs.plot(scalings=None, title="Epochs")
    ticks = fig._get_ticklabels("x")
    assert ticks == ["2"]
    browser_backend._close_all()
    # covariance / whitening
    assert len(noise_cov_io["names"]) == 366  # all channels
    assert noise_cov_io["bads"] == []
    assert epochs.info["bads"] == []  # all good
    with pytest.warns(RuntimeWarning, match="projection"):
        epochs.plot(noise_cov=noise_cov_io)
    browser_backend._close_all()
    # add a channel to the epochs.info['bads']
    epochs.info["bads"] = [epochs.ch_names[0]]
    with pytest.warns(RuntimeWarning, match="projection"):
        epochs.plot(noise_cov=noise_cov_io)
    browser_backend._close_all()
    # add a channel to cov['bads']
    noise_cov_io["bads"] = [epochs.ch_names[1]]
    with _record_warnings(), pytest.warns(RuntimeWarning, match="projection"):
        epochs.plot(noise_cov=noise_cov_io)
    browser_backend._close_all()
    # have a data channel missing from the covariance. Truncate the channel
    # names and *every* axis of the covariance data so they stay consistent;
    # ``data[:306][:306]`` only sliced the first axis (twice), leaving the
    # second axis of a 2D covariance untruncated.
    n_keep = 306
    cov_data = noise_cov_io["data"]
    noise_cov_io["names"] = noise_cov_io["names"][:n_keep]
    noise_cov_io["data"] = cov_data[(slice(n_keep),) * cov_data.ndim]
    with _record_warnings(), pytest.warns(RuntimeWarning, match="projection"):
        epochs.plot(noise_cov=noise_cov_io)
    browser_backend._close_all()
    # other options
    fig = epochs[0].plot(picks=[0, 2, 3], scalings=None)
    fig._fake_keypress("escape")
    with pytest.raises(ValueError, match="No appropriate channels found"):
        epochs.plot(picks=[])
    # gh-5906
    assert len(epochs_full) == 7
    epochs_full.info["bads"] = [epochs_full.ch_names[0]]
    capsys.readouterr()
    # test title error handling
    with pytest.raises(TypeError, match="title must be None or a string, got"):
        epochs_full.plot(title=7)
    # test auto-generated title, and selection mode
    epochs_full.plot(group_by="selection", title="")
@pytest.mark.parametrize(
    "scalings",
    [
        dict(mag=1e-12, grad=1e-11, stim="auto"),  # per-channel-type dict
        None,
        "auto",
    ],
)
def test_plot_epochs_scalings(epochs, scalings, browser_backend):
    """Each accepted form of ``scalings`` (dict, None, "auto") can be plotted."""
    epochs.plot(scalings=scalings)
def test_plot_epochs_colors(epochs, browser_backend):
    """Check ``epoch_colors`` validation (used by autoreject) and ``event_color``."""
    n_channels = len(epochs.ch_names)
    # one list of per-channel colors for every epoch, all red
    epoch_colors = [["r"] * n_channels for _event in epochs.events]
    epochs.plot(epoch_colors=epoch_colors)
    # the epochs object holds a single epoch, so two color lists is one too many
    with pytest.raises(ValueError, match="length equal to the number of epo"):
        epochs.plot(epoch_colors=[["r"], ["b"]])
    # each epoch's list must provide one color per channel
    with pytest.raises(ValueError, match=r"epoch colors for epoch \d+ has"):
        epochs.plot(epoch_colors=[["r"]])
    # a single event color is accepted as well
    epochs.plot(event_color="b")
def test_plot_epochs_scale_bar(epochs, browser_backend):
    """Test scale bar labels for epochs (mag & grad in this instance).

    The expected labels depend on the browser backend. The matplotlib
    browser also shows time ("s") labels. An unknown backend fails with an
    explicit message instead of an unbound-variable error.
    """
    fig = epochs.plot()
    texts = fig._get_scale_bar_texts()
    expected_by_backend = {
        "pyqtgraph": ("800.0 fT/cm", "2000.0 fT"),
        "matplotlib": ("800.0 fT/cm", "0.55 s", "2000.0 fT", "0.55 s"),
    }
    wants = expected_by_backend.get(browser_backend.name)
    if wants is None:
        pytest.fail(
            f"no expected scale bar texts for browser backend "
            f"{browser_backend.name!r}"
        )
    assert len(texts) == len(wants)
    assert texts == wants
def test_plot_epochs_clicks(epochs, epochs_full, capsys, browser_backend):
    """Test plot_epochs mouse interaction.

    Exercises, in order:

    - marking, unmarking and re-marking a bad epoch by clicking a trace,
      then checking that the marked epoch is dropped when the browser closes;
    - marking a bad channel by clicking its name;
    - clicking both scrollbars;
    - marking an epoch after scrolling right and checking nothing
      "out of bounds" is printed when scrolling back;
    - opening a child image figure with a right click on a channel name;
    - scroll-wheel events.
    """
    fig = epochs.plot(events=True)
    # data coordinates of a point on the first trace (sample index 3)
    x = fig.mne.traces[0].get_xdata()[3]
    y = fig.mne.traces[0].get_ydata()[3]
    n_epochs = len(epochs)
    epoch_num = fig.mne.inst.selection[0]
    # test (un)marking bad epochs
    fig._fake_click((x, y), xform="data")  # mark a bad epoch
    assert epoch_num in fig.mne.bad_epochs
    fig._fake_click((x, y), xform="data")  # unmark it
    assert epoch_num not in fig.mne.bad_epochs
    fig._fake_click((x, y), xform="data")  # mark it bad again
    assert epoch_num in fig.mne.bad_epochs
    # closing the browser should drop the epoch that is still marked bad
    fig._fake_keypress("escape")  # close and drop epochs
    fig._close_event()  # XXX workaround, MPL Agg doesn't trigger close event
    assert n_epochs - 1 == len(epochs)
    # test marking bad channels
    # need more than 1 epoch this time
    fig = epochs_full.plot(n_epochs=3)
    first_ch = fig._get_ticklabels("y")[0]
    assert first_ch not in fig.mne.info["bads"]
    fig._click_ch_name(ch_index=0, button=1)  # click ch name to mark bad
    assert first_ch in fig.mne.info["bads"]
    # test clicking scrollbars
    fig._fake_click((0.5, 0.5), ax=fig.mne.ax_vscroll)
    fig._fake_click((0.5, 0.5), ax=fig.mne.ax_hscroll)
    # test moving bad epoch offscreen
    fig._fake_keypress("right")  # move right
    # a point near the end of the first trace (third-to-last sample)
    x = fig.mne.traces[0].get_xdata()[-3]
    y = fig.mne.traces[0].get_ydata()[-3]
    fig._fake_click((x, y), xform="data")  # mark a bad epoch
    fig._fake_keypress("left")  # move back
    # scrolling away from the marked epoch must not report "out of bounds"
    out, err = capsys.readouterr()
    assert "out of bounds" not in out
    assert "out of bounds" not in err
    fig._fake_keypress("escape")
    fig._close_event()  # XXX workaround, MPL Agg doesn't trigger close event
    # the epoch marked bad above was dropped on close
    assert len(epochs_full) == 6
    # test rightclick → image plot
    fig = epochs_full.plot()
    fig._click_ch_name(ch_index=0, button=3)  # show image plot
    assert len(fig.mne.child_figs) == 1
    # test scroll wheel
    fig._fake_scroll(0.5, 0.5, -0.5)  # scroll down
    fig._fake_scroll(0.5, 0.5, 0.5)  # scroll up
def test_plot_epochs_keypresses(epochs_full, browser_backend):
    """Test plot_epochs keypress interaction.

    Adds vertical lines by clicking between the first two traces. It then
    replays a fixed key sequence twice and finally removes the lines with a
    right click.
    """
    # we need more than 1 epoch
    epochs_full.drop_bad(dict(mag=4e-12))  # for histogram plot coverage
    fig = epochs_full.plot(n_epochs=3)
    # make sure green vlines are visible first (for coverage)
    sample_idx = len(epochs_full.times) // 2  # halfway through the first epoch
    x = fig.mne.traces[0].get_xdata()[sample_idx]
    # vertical midpoint between the first two traces at that sample
    y = (
        fig.mne.traces[0].get_ydata()[sample_idx]
        + fig.mne.traces[1].get_ydata()[sample_idx]
    ) / 2
    fig._fake_click([x, y], xform="data")  # click between traces
    # test keys. Order matters: the sequence ends with "b", and per the loop
    # comment below the second pass runs in butterfly view.
    keys = (
        "pagedown",
        "down",
        "up",
        "down",
        "right",
        "left",
        "-",
        "+",
        "=",
        "d",
        "d",
        "pageup",
        "home",
        "shift+right",
        "end",
        "shift+left",
        "z",
        "z",
        "s",
        "s",
        "?",
        "h",
        "j",
        "b",
    )
    for key in keys * 2:  # test twice → once in normal, once in butterfly view
        fig._fake_keypress(key)
    fig._fake_click([x, y], xform="data", button=3)  # remove vlines
def _get_event_lines_and_texts(fig):
    """Return ``(lines, texts)`` for the event markers drawn in an epochs browser.

    If ``fig.mne.event_lines`` exposes ``get_segments`` (matplotlib backend),
    its segments are returned together with the string of each entry in
    ``fig.mne.event_texts``. Otherwise both attributes are returned as-is.
    """
    browser = fig.mne
    event_lines = browser.event_lines
    labels = browser.event_texts
    if not hasattr(event_lines, "get_segments"):
        return event_lines, labels
    # matplotlib backend: unwrap the line collection and the text artists
    return event_lines.get_segments(), [label.get_text() for label in labels]
@pytest.mark.parametrize(
    "event_id,expected_texts",
    [
        (False, set("123")),
        (True, set("abc")),
        (dict(f=1), set("fbc")),
        (dict(a=1), set("abc")),
    ],
)
def test_plot_overlapping_epochs_with_events(browser_backend, event_id, expected_texts):
    """Test drawing of event lines in overlapping epochs.

    ``expected_texts`` holds the event labels expected in the browser for the
    given ``event_id``. pytest reuses parametrize argument objects across the
    generated test items (e.g. one item per browser backend). The set is
    therefore never mutated here; a new set is derived for the second check.
    """
    data = np.zeros(shape=(3, 2, 100))  # 3 epochs, 2 channels, 100 samples
    sfreq = 100
    info = create_info(ch_names=("a", "b"), ch_types=("misc", "misc"), sfreq=sfreq)
    # 90% overlap, so all 3 events should appear in all 3 epochs when plotted:
    events = np.column_stack(([40, 50, 60], [0, 0, 0], [1, 2, 3]))
    epochs = EpochsArray(
        data, info, tmin=-0.4, events=events, event_id=dict(a=1, b=2, c=3)
    )
    fig = epochs.plot(events=events, picks="misc", event_id=event_id)
    # check that the event lines are there and the labels are correct
    lines, texts = _get_event_lines_and_texts(fig)
    assert len(lines) == len(epochs) * len(events)
    # TODO: Qt browser doesn't show event names, only integers
    if browser_backend.name == "matplotlib":
        assert set(texts) == expected_texts
    # plot one epoch with its defining event plus events at its first & last sample
    # (regression test for https://mne.discourse.group/t/6334)
    events = np.vstack(([[0, 0, 4]], events[[0]], [[99, 0, 4]]))
    fig = epochs[0].plot(events=events, picks="misc", event_id=event_id)
    # Only event 1 (or its label) and the new event 4 are shown now. Build a
    # new set rather than mutating the shared parametrized one in place.
    expected_single_epoch = (expected_texts | {"4"}) - {"2", "3", "b", "c"}
    lines, texts = _get_event_lines_and_texts(fig)
    assert len(lines) == len(events)
    # TODO: Qt browser doesn't show event names, only integers
    if browser_backend.name == "matplotlib":
        assert set(texts) == expected_single_epoch
def test_epochs_plot_sensors(epochs):
    """Smoke-test ``Epochs.plot_sensors`` on the test epochs."""
    plot_sensors = epochs.plot_sensors
    plot_sensors()
def test_plot_epochs_nodata(browser_backend):
    """Plotting epochs whose only channels are stim channels must raise."""
    rng = np.random.RandomState(0)
    n_epochs, n_channels, n_times = 10, 2, 1000
    stim_info = create_info(n_channels, 1000.0, "stim")
    epochs = EpochsArray(rng.randn(n_epochs, n_channels, n_times), stim_info)
    with pytest.raises(ValueError, match="consider passing picks explicitly"):
        epochs.plot()
@pytest.mark.slowtest
def test_plot_epochs_image(epochs):
    """Test plotting of epochs image.

    Covers figure creation (one figure per channel type) and the main keyword
    options: picks, sigma, combine, order, overlay_times, vmin/vmax, fig,
    axes, group_by, units/scalings and ts_args. It also covers the
    TypeError/ValueError/RuntimeWarning paths for invalid argument
    combinations, and plotting with a ref_meg channel.
    """
    # NOTE(review): the plt.get_fignums() counts below assume no figures are
    # open when this test starts.
    figs = epochs.plot_image()
    assert len(figs) == 2  # one fig per ch_type (test data has mag, grad)
    assert len(plt.get_fignums()) == 2
    figs = epochs.plot_image()
    assert len(figs) == 2
    assert len(plt.get_fignums()) == 4  # should create new figures
    epochs.plot_image(picks="mag", sigma=0.1)
    epochs.plot_image(picks=[0, 1], combine="mean", ts_args=dict(show_sensors=False))
    epochs.plot_image(
        picks=[1], order=[0], overlay_times=[0.1], vmin=0.01, title="test"
    )
    plt.close("all")
    epochs.plot_image(picks=[1], overlay_times=[0.1], vmin=-0.001, vmax=0.001)
    plt.close("all")
    # vmin may also be given as a callable (here returning the minimum of x)
    epochs.plot_image(picks=[1], vmin=lambda x: x.min())
    # test providing figure
    fig, axs = plt.subplots(3, 1)
    epochs.plot_image(picks=[1], fig=fig)
    # test providing axes instance
    epochs.plot_image(picks=[1], axes=axs[0], evoked=False, colorbar=False)
    plt.close("all")
    # test order=callable (here: reverse the epoch order)
    epochs.plot_image(
        picks=[0, 1], order=lambda times, data: np.arange(len(data))[::-1]
    )
    # test warning
    with (
        _record_warnings(),
        pytest.warns(RuntimeWarning, match="Only one channel in group"),
    ):
        epochs.plot_image(picks=[1], combine="mean")
    # group_by should be a dict
    with pytest.raises(TypeError, match="dict or None"):
        epochs.plot_image(group_by="foo")
    # units and scalings keys must match
    with pytest.raises(ValueError, match="Scalings and units must have the"):
        epochs.plot_image(units=dict(hi=1), scalings=dict(ho=1))
    plt.close("all")
    # test invert_y
    epochs.plot_image(ts_args=dict(invert_y=True))
    # can't combine different sensor types
    with pytest.raises(ValueError, match="Cannot combine sensors of differ"):
        epochs.plot_image(group_by=dict(foo=[0, 1, 2]))
    # can't pass both fig and axes
    with pytest.raises(ValueError, match='one of "fig" or "axes" must be'):
        epochs.plot_image(fig="foo", axes="bar")
    # wrong number of axes in fig
    with pytest.raises(ValueError, match='"fig" must contain . axes, got .'):
        epochs.plot_image(fig=plt.figure())
    # only 1 group allowed when fig is passed
    with pytest.raises(ValueError, match='"group_by" can only have one group'):
        fig, axs = plt.subplots(3, 1)
        epochs.plot_image(fig=fig, group_by=dict(foo=[0, 1], bar=[5, 6]))
    del fig, axs
    plt.close("all")
    # must pass correct number of axes (1, 2, or 3)
    with pytest.raises(ValueError, match="is a list, can only plot one group"):
        fig, axs = plt.subplots(1, 3)
        epochs.plot_image(axes=axs)
    # For each (expected number of axes, options) pair: the exact number of
    # axes works, one extra axes raises.
    for length, kwargs in (
        [3, dict()],
        [2, dict(evoked=False)],
        [2, dict(colorbar=False)],
        [1, dict(evoked=False, colorbar=False)],
    ):
        fig, axs = plt.subplots(1, length + 1)
        epochs.plot_image(picks="mag", axes=axs[:length], **kwargs)
        with pytest.raises(ValueError, match='"axes" must be length ., got .'):
            epochs.plot_image(picks="mag", axes=axs, **kwargs)
    plt.close("all")
    # mismatch between axes dict keys and group_by dict keys
    with pytest.raises(ValueError, match='must match the keys in "group_by"'):
        epochs.plot_image(axes=dict())
    # wrong number of axes in dict (reuses ``axs`` from the last loop iteration)
    match = 'each value in "axes" must be a list of . axes, got .'
    with pytest.raises(ValueError, match=match):
        epochs.plot_image(
            axes=dict(foo=axs[:2], bar=axs[:3]), group_by=dict(foo=[0, 1], bar=[5, 6])
        )
    # bad value of "combine"
    with pytest.raises(ValueError, match='"combine" must be None, a callable'):
        epochs.plot_image(combine="foo")
    # mismatched picks and overlay_times
    with pytest.raises(ValueError, match="size of overlay_times parameter"):
        epochs.plot_image(picks=[1], overlay_times=[0.1, 0.2])
    # bad overlay times
    with pytest.warns(RuntimeWarning, match="fall outside"):
        epochs.plot_image(overlay_times=[999.0])
    # mismatched picks and order
    with pytest.raises(ValueError, match="must match the length of the data"):
        epochs.plot_image(picks=[1], order=[0, 1])
    # with a ref MEG channel (that we "convert" from a grad channel)
    with pytest.warns(RuntimeWarning, match=".* from T/m to T.$"):
        epochs.set_channel_types({epochs.ch_names[0]: "ref_meg"})
    epochs.plot_image()
    plt.close("all")
def test_plot_epochs_image_emg():
    """An epochs image can be drawn for a single EMG channel."""
    ch_name = "EMG 001"
    emg_info = create_info([ch_name], sfreq=100, ch_types="emg")
    # 2 epochs x 1 channel x 10 samples, constant ones
    epochs = EpochsArray(data=np.ones((2, 1, 10)), info=emg_info)
    epochs.plot_image(ch_name, ts_args={"show_sensors": False})
def test_plot_drop_log(epochs_unloaded):
    """Test plotting a drop log from Epochs and from raw drop-log tuples."""
    # the drop log cannot be plotted before bad epochs have been dropped
    with pytest.raises(ValueError, match="bad epochs have not yet been"):
        epochs_unloaded.plot_drop_log()
    epochs_unloaded.drop_bad()
    epochs_unloaded.plot_drop_log()
    # the standalone function accepts a tuple of tuples of strings
    valid_logs = (
        (("One",), (), ()),
        (("One",), ("Two",), ()),
        (("One",), ("One", "Two"), ()),
    )
    for drop_log in valid_logs:
        plot_drop_log(drop_log)
    # anything else is rejected with a TypeError
    invalid_logs = ([], ([],), (1,))
    for bad_log in invalid_logs:
        with pytest.raises(TypeError, match="tuple of tuple of str"):
            plot_drop_log(bad_log)
    plt.close("all")
def test_plot_psd_epochs(epochs):
    """Test plotting epochs psd (+topomap), including a flat channel."""
    spectrum = epochs.compute_psd()
    old_defaults = dict(picks="data", exclude="bads")
    # (average, spatial_colors) combinations, all with amplitude=False
    for average, spatial_colors in ((True, False), (False, True), (False, False)):
        spectrum.plot(
            average=average,
            amplitude=False,
            spatial_colors=spatial_colors,
            **old_defaults,
        )
    # a band that contains no frequency bins must raise
    with pytest.raises(RuntimeError, match="No frequencies in band"):
        spectrum.plot_topomap(bands=dict(foo=(0, 0.01)))
    plt.close("all")
    # defaults: 5 bands (δ, θ, α, β, γ), each with a topomap and a colorbar
    fig = spectrum.plot_topomap()
    assert len(fig.axes) == 10
    # with vlim="joint", the five topomaps share identical color limits
    fig = spectrum.plot_topomap(vlim="joint")
    norms = [ax.images[0].norm for ax in fig.axes[:5]]
    reference = norms[0]
    assert all(norm.vmin == reference.vmin for norm in norms[1:])
    assert all(norm.vmax == reference.vmax for norm in norms[1:])
    # single-bin bands and old-style list-of-tuple input
    spectrum.plot_topomap(bands=[(20, "20 Hz"), (15, 25, "15-25 Hz")])
    # zero out channel 2 in the first epoch; a warning naming it is expected
    err_str = f"for channel {epochs.ch_names[2]}"
    epochs.get_data(copy=False)[0, 2, :] = 0
    for dB in (True, False):
        with _record_warnings(), pytest.warns(UserWarning, match=err_str):
            epochs.compute_psd().plot(dB=dB)
def test_plot_psdtopo_nirs(fnirs_epochs):
    """Check PSD topomaps for fNIRS epochs using single-frequency bands."""
    spectrum = fnirs_epochs.compute_psd()
    fig = spectrum.plot_topomap(bands={"0.2 Hz": 0.2, "0.4 Hz": 0.4, "0.8 Hz": 0.8})
    # 3 bands × (topomap axes + colorbar axes)
    assert len(fig.axes) == 6
@testing.requires_testing_data
def test_plot_epochs_ctf(raw_ctf, browser_backend):
    """Test of basic CTF plotting.

    Picks a small set of CTF channels and epochs them with fixed-length
    events. The epochs are plotted normally and in butterfly mode. A key
    sequence, scroll events and a resize are then replayed before closing
    with "escape".
    """
    raw_ctf.pick(
        [
            "UDIO001",
            "UPPT001",
            "SCLK01-177",
            "BG1-4304",
            "MLC11-4304",
            "EEG058",
            "UADC007-4302",
        ],
    )
    evts = make_fixed_length_events(raw_ctf)
    epochs = Epochs(raw_ctf, evts, preload=True)
    epochs.plot()
    browser_backend._close_all()
    # test butterfly
    fig = epochs.plot(butterfly=True)
    # leave fullscreen testing to Raw / _figure abstraction (too annoying here)
    keys = (
        "b",
        "b",
        "pagedown",
        "down",
        "up",
        "down",
        "right",
        "left",
        "-",
        "+",
        "=",
        "d",
        "d",
        "pageup",
        "home",
        "end",
        "z",
        "z",
        "s",
        "s",
        "?",
        "h",
        "j",
    )
    for key in keys:
        fig._fake_keypress(key)
    fig._fake_scroll(0.5, 0.5, -0.5)  # scroll down
    fig._fake_scroll(0.5, 0.5, 0.5)  # scroll up
    fig._resize_by_factor(1)
    fig._fake_keypress("escape")  # close and drop epochs
@pytest.mark.slowtest
@testing.requires_testing_data
def test_plot_psd_epochs_ctf(raw_ctf):
    """Test plotting CTF epochs psd (+topomap)."""
    epochs = Epochs(raw_ctf, make_fixed_length_events(raw_ctf), preload=True)
    # EEG060 is flat in this dataset, so computing the PSD warns about it
    with _record_warnings(), pytest.warns(UserWarning, match="for channel EEG060"):
        spectrum = epochs.compute_psd()
    for use_db in (True, False):
        spectrum.plot(dB=use_db)
    # without the flat channel, plot with the old default picks/exclude
    spectrum.drop_channels(["EEG060"])
    spectrum.plot(
        spatial_colors=False,
        average=False,
        amplitude=False,
        picks="data",
        exclude="bads",
    )
    # a band that contains no frequency bins must raise
    with pytest.raises(RuntimeError, match="No frequencies in band"):
        spectrum.plot_topomap(bands=[(0, 0.01, "foo")])
    spectrum.plot_topomap()
def test_plot_epochs_selection_butterfly(raw, browser_backend):
    """Grouping channels by selection works together with butterfly mode."""
    first_event = make_fixed_length_events(raw)[:1]
    epochs = Epochs(raw, first_event, tmin=0, tmax=0.5, preload=True, baseline=None)
    assert len(epochs) == 1
    epochs.plot(butterfly=True, group_by="selection")