# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib import gridspec
from matplotlib.collections import PolyCollection
from matplotlib.colors import same_color
from mpl_toolkits.axes_grid1.parasite_axes import HostAxes # spatial_colors
from numpy.testing import assert_allclose
import mne
from mne import (
Epochs,
compute_covariance,
compute_proj_evoked,
make_fixed_length_events,
read_cov,
read_events,
)
from mne._fiff.constants import FIFF
from mne.datasets import testing
from mne.io import read_raw_fif
from mne.stats.parametric import _parametric_ci
from mne.utils import _record_warnings, catch_logging
from mne.viz import plot_compare_evokeds, plot_evoked_white
from mne.viz.utils import _fake_click, _get_cmap
# Test recordings live under ``<package>/io/tests/data`` relative to this file.
base_dir = Path(__file__).parents[2].joinpath("io", "tests", "data")
evoked_fname = base_dir / "test-ave.fif"
raw_fname = base_dir / "test_raw.fif"
raw_sss_fname = base_dir / "test_chpi_raw_sss.fif"
cov_fname = base_dir / "test-cov.fif"
event_name = base_dir / "test-eve.fif"

# Epoching parameters shared by the epoch-building helpers in this module.
event_id = 1
tmin = -0.1
tmax = 0.1

# CTF recording inside the testing dataset; ``download=False`` means the
# dataset is never fetched as a side effect of importing this module.
ctf_fname = testing.data_path(download=False).joinpath("CTF", "testdata_ctf.ds")

# Use a subset of channels for plotting speed; make sure we have a
# magnetometer and a pair of grad pairs for topomap.
default_picks = (
    0, 1, 2, 3, 4, 6, 7, 61, 122, 183, 244, 305,
    315, 316, 317, 318,  # the last four picks are the EEG channels
)  # fmt: skip

# Channel indices used as the ``group_by`` selection in the image tests.
sel = (0, 7)
def _get_epochs(picks=default_picks):
    """Create a short, decimated ``Epochs`` object from the test recording.

    Parameters
    ----------
    picks : object
        Channel selection forwarded unchanged to ``Epochs``. Defaults to
        ``default_picks``.

    Returns
    -------
    epochs : Epochs
        Epochs built from the first five events, with two channels marked bad
        and the projector info normalized.
    """
    events = read_events(event_name)
    raw = read_raw_fif(raw_fname)
    # Replace whatever projectors the raw object carries with an empty list.
    raw.add_proj([], remove_existing=True)
    epochs = Epochs(
        raw,
        events[:5],
        event_id=event_id,
        tmin=tmin,
        tmax=tmax,
        picks=picks,
        decim=10,
        verbose="error",
    )
    names = epochs.ch_names
    # With the default picks these are one MEG and one EEG channel.
    epochs.info["bads"] = [names[-5], names[-1]]
    epochs.info.normalize_proj()
    return epochs
def _get_epochs_delayed_ssp():
    """Create ``Epochs`` whose projectors are configured with ``proj="delayed"``.

    Returns
    -------
    epochs : Epochs
        Epochs built from the first ten events of the test recording using
        ``default_picks`` and a magnetometer rejection threshold.
    """
    events = read_events(event_name)
    raw = read_raw_fif(raw_fname)
    epochs = Epochs(
        raw,
        events[:10],
        event_id=event_id,
        tmin=tmin,
        tmax=tmax,
        picks=default_picks,
        proj="delayed",
        reject={"mag": 4e-12},
        verbose="error",
    )
    epochs.info.normalize_proj()
    return epochs
def test_plot_evoked_cov():
    """Check ``Evoked.plot`` when a ``noise_cov`` argument is supplied."""
    evoked = _get_epochs().average()
    noise_cov = read_cov(cov_fname)
    noise_cov["projs"] = []  # avoid warnings
    with pytest.warns(RuntimeWarning, match="No average EEG reference"):
        evoked.plot(noise_cov=noise_cov, time_unit="s")

    # Invalid ``noise_cov`` values: wrong type, then a missing file.
    invalid_inputs = (
        (1.0, TypeError, "Covariance"),
        ("nonexistent-cov.fif", FileNotFoundError, "File does not exist"),
    )
    for bad_cov, exc_type, pattern in invalid_inputs:
        with pytest.raises(exc_type, match=pattern):
            evoked.plot(noise_cov=bad_cov, time_unit="s")

    # Covariance estimated from the recording stored in ``raw_sss_fname``.
    sss_raw = read_raw_fif(raw_sss_fname)
    sss_events = make_fixed_length_events(sss_raw)
    sss_epochs = Epochs(sss_raw, sss_events, picks=default_picks)
    sss_cov = compute_covariance(sss_epochs)
    evoked_sss = sss_epochs.average()
    with _record_warnings(), pytest.warns(RuntimeWarning, match="relative scaling"):
        evoked_sss.plot(noise_cov=sss_cov, time_unit="s")
    plt.close("all")
@pytest.mark.slowtest
def test_plot_evoked():
    """Test evoked.plot.

    Covers projector handling, bad-channel coloring, click interaction, the
    ``gfp``/``zorder``/``spatial_colors``/``highlight`` options, sensor
    plotting, and spatial colors when one channel location is NaN.
    """
    epochs = _get_epochs()
    evoked = epochs.average()
    assert evoked.proj is False
    fig = evoked.plot(
        proj=True, hline=[1], exclude=[], window_title="foo", time_unit="s"
    )
    amplitudes = _get_amplitudes(fig)
    assert len(amplitudes) == len(default_picks)
    # Plotting with ``proj=True`` must not change the instance itself.
    assert evoked.proj is False
    assert evoked.info["bads"] == ["MEG 2641", "EEG 004"]
    eeg_lines = fig.axes[2].lines
    n_eeg = sum(ch_type == "eeg" for ch_type in evoked.get_channel_types())
    assert len(eeg_lines) == n_eeg == 4
    # Exactly one EEG trace (the bad channel) is drawn in gray ("0.5").
    n_bad = sum(same_color(line.get_color(), "0.5") for line in eeg_lines)
    assert n_bad == 1
    # Test a click
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], "data")
    _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], "data")
    # plot with bad channels excluded & spatial_colors & zorder
    evoked.plot(exclude="bads", time_unit="s")

    # test selective updating of dict keys is working.
    evoked.plot(hline=[1], units=dict(mag="femto foo"), time_unit="s")
    evoked_delayed_ssp = _get_epochs_delayed_ssp().average()
    evoked_delayed_ssp.plot(proj="interactive", time_unit="s")
    evoked_delayed_ssp.apply_proj()
    # Interactive projection must fail once projectors are applied ...
    with pytest.raises(RuntimeError):
        evoked_delayed_ssp.plot(proj="interactive", time_unit="s")
    with evoked_delayed_ssp.info._unlock():
        evoked_delayed_ssp.info["projs"] = []
    # ... and after the projector list has been emptied ...
    with pytest.raises(RuntimeError):
        evoked_delayed_ssp.plot(proj="interactive", time_unit="s")
    # ... and when ``axes`` is passed together with ``proj="interactive"``.
    with pytest.raises(RuntimeError):
        evoked_delayed_ssp.plot(proj="interactive", axes="foo", time_unit="s")
    plt.close("all")

    # test `gfp='only'`: GFP (EEG) and RMS (MEG)
    fig, ax = plt.subplots(3)
    evoked.plot(gfp="only", time_unit="s", axes=ax)
    assert len(ax[0].lines) == len(ax[1].lines) == len(ax[2].lines) == 1
    assert ax[0].get_title() == "EEG (3 channels)"
    assert ax[0].texts[0].get_text() == "GFP"
    assert ax[1].get_title() == "Gradiometers (9 channels)"
    assert ax[1].texts[0].get_text() == "RMS"
    assert ax[2].get_title() == "Magnetometers (2 channels)"
    # Check the magnetometer axes here; this line used to re-check ax[1].
    assert ax[2].texts[0].get_text() == "RMS"
    plt.close("all")

    # Test invalid `gfp`
    with pytest.raises(ValueError):
        evoked.plot(gfp="foo", time_unit="s")

    # plot with bad channels excluded, spatial_colors, zorder & pos. layout
    evoked.rename_channels({"MEG 0133": "MEG 0000"})
    evoked.plot(
        exclude=evoked.info["bads"],
        spatial_colors=True,
        gfp=True,
        zorder="std",
        time_unit="s",
    )
    evoked.plot(exclude=[], spatial_colors=True, zorder="unsorted", time_unit="s")
    with pytest.raises(TypeError):
        evoked.plot(zorder="asdf", time_unit="s")
    plt.close("all")

    evoked.plot_sensors()  # Test plot_sensors
    plt.close("all")

    evoked.pick(evoked.ch_names[:4])
    with catch_logging() as log_file:
        evoked.plot(verbose=True, time_unit="s")
    assert "Need more than one" in log_file.getvalue()

    # Test highlight: one PolyCollection per highlighted interval on each of
    # the regular (non-parasite) axes.
    for highlight in [(0, 0.1), [(0, 0.1), (0.1, 0.2)]]:
        fig = evoked.plot(time_unit="s", highlight=highlight)
        regular_axes = [ax for ax in fig.axes if not isinstance(ax, HostAxes)]
        for ax in regular_axes:
            highlighted_areas = [
                child
                for child in ax.get_children()
                if isinstance(child, PolyCollection)
            ]
            assert len(highlighted_areas) == len(np.atleast_2d(highlight))

    with pytest.raises(ValueError, match="must be reshapable into a 2D array"):
        evoked.plot(time_unit="s", highlight=0.1)

    # set one channel location to nan, confirm spatial_colors still works
    evoked = _get_epochs().load_data().average("grad")  # reload data
    evoked.info["chs"][0]["loc"][:] = np.nan
    fig = evoked.plot(time_unit="s", spatial_colors=True)
    # Convert to a float array so the element-wise comparisons below are real
    # array operations. The previous check compared a Python list to 0 (always
    # False) and combined the masks with ``&``, so it could never fail.
    line_clr = np.array(
        [line.get_color() for line in fig.axes[0].get_lines()], dtype=float
    )
    # A single NaN location must not turn every line color into NaN or zero.
    assert not np.all(np.isnan(line_clr) | (line_clr == 0))
def test_constrained_layout():
    """Check that plotting into user axes keeps constrained layout enabled."""
    evoked = mne.read_evokeds(evoked_fname)[0]
    first_two = evoked.ch_names[:2]
    evoked.pick(first_two)

    fig, ax = plt.subplots(1, 1, layout="constrained")
    assert fig.get_constrained_layout()
    # smoke test: plotting must neither fail nor switch the layout engine off
    evoked.plot(axes=ax)
    assert fig.get_constrained_layout()
    plt.close("all")
def _get_amplitudes(fig):
    """Collect the y-data of all array-valued lines drawn on regular axes.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure whose axes are scanned.

    Returns
    -------
    amplitudes : ndarray
        Stack of the y-data arrays, one row per retained line.
    """
    traces = []
    for ax in fig.axes:
        # ignore the spatial_colors parasite axes
        if isinstance(ax, HostAxes):
            continue
        for line in ax.get_lines():
            ydata = line.get_ydata()
            # this will exclude hlines, which are lists not arrays
            if isinstance(ydata, np.ndarray):
                traces.append(ydata)
    return np.array(traces)
@pytest.mark.parametrize(
    "picks, rlims, avg_proj",
    [
        (default_picks[:-4], (0.59, 0.61), False),  # MEG
        (np.arange(340, 360), (0.56, 0.57), True),  # EEG
        (np.arange(340, 360), (0.79, 0.81), False),  # EEG
    ],
)
def test_plot_evoked_reconstruct(picks, rlims, avg_proj):
    """Test proj="reconstruct".

    Parameters
    ----------
    picks : sequence of int
        Channel indices used to build the evoked data.
    rlims : tuple of float
        Exclusive lower and upper bounds for the normalized correlation
        between the reconstructed and the original amplitudes.
    avg_proj : bool
        Whether an average EEG reference projector is added and applied
        before plotting.
    """
    evoked = _get_epochs(picks=picks).average()
    if avg_proj:
        evoked.set_eeg_reference(projection=True).apply_proj()
        assert len(evoked.info["projs"]) == 1
        assert evoked.proj is True
    else:
        assert len(evoked.info["projs"]) == 0
        assert evoked.proj is False
    fig = evoked.plot(
        proj=True, hline=[1], exclude=[], window_title="foo", time_unit="s"
    )
    amplitudes = _get_amplitudes(fig)
    assert len(amplitudes) == len(picks)
    assert evoked.proj is avg_proj

    fig = evoked.plot(proj="reconstruct", exclude=[])
    amplitudes_recon = _get_amplitudes(fig)
    if avg_proj is False:
        assert_allclose(amplitudes, amplitudes_recon)
    proj = compute_proj_evoked(evoked.copy().crop(None, 0).apply_proj())
    n_projs_before = len(evoked.info["projs"])
    evoked.add_proj(proj)
    # The former check, ``assert len(...) == 2 if len(picks) == 3 else 4``,
    # parsed as ``assert (len(...) == 2) if len(picks) == 3 else 4`` and so
    # always asserted the truthy constant 4. Verify instead that projectors
    # were computed and that all of them were added to the existing ones.
    assert len(proj) > 0
    assert len(evoked.info["projs"]) == n_projs_before + len(proj)
    fig = evoked.plot(proj=True, exclude=[])
    amplitudes_proj = _get_amplitudes(fig)
    fig = evoked.plot(proj="reconstruct", exclude=[])
    amplitudes_recon = _get_amplitudes(fig)
    assert len(amplitudes_recon) == len(picks)
    norm = np.linalg.norm(amplitudes)
    norm_proj = np.linalg.norm(amplitudes_proj)
    norm_recon = np.linalg.norm(amplitudes_recon)
    # Normalized correlation between reconstructed and original amplitudes.
    r = np.dot(amplitudes_recon.ravel(), amplitudes.ravel()) / (norm_recon * norm)
    assert rlims[0] < r < rlims[1]
    # Reconstruction must restore amplitude relative to the projected data.
    assert 1.05 * norm_proj < norm_recon
    if not avg_proj:
        assert norm_proj < norm * 0.9

    cov = read_cov(cov_fname)
    with pytest.raises(ValueError, match='Cannot use proj="reconstruct"'):
        evoked.plot(noise_cov=cov, proj="reconstruct")
    plt.close("all")
def test_plot_evoked_image():
    """Exercise ``Evoked.plot_image``: NaN data, masks, picks, names, grouping."""
    evoked = _get_epochs().average()
    evoked.plot_image(proj=True, time_unit="ms")

    # fail nicely on NaN
    evoked_nan = evoked.copy()
    evoked_nan.data[:, 0] = np.nan
    with pytest.raises(ValueError):
        evoked_nan.plot()
    with np.errstate(invalid="ignore"):
        for plot_func in (evoked_nan.plot_image, evoked_nan.plot_joint):
            with pytest.raises(ValueError):
                plot_func()

    # test mask
    two_picks = [1, 2]
    evoked.plot_image(picks=two_picks, mask=evoked.data > 0, time_unit="s")
    full_mask = np.ones(evoked.data.shape).astype(bool)
    evoked.plot_image(
        picks=two_picks,
        mask_cmap=None,
        colorbar=False,
        mask=full_mask,
        time_unit="s",
    )
    with _record_warnings(), pytest.warns(RuntimeWarning, match="not adding contour"):
        evoked.plot_image(picks=two_picks, mask=None, mask_style="both", time_unit="s")
    # A mask whose shape differs from the data must be rejected.
    wrong_shape_mask = evoked.data[1:, 1:] > 0
    with pytest.raises(ValueError, match="must have the same shape"):
        evoked.plot_image(mask=wrong_shape_mask, time_unit="s")

    # plot with bad channels excluded
    evoked.plot_image(exclude="bads", cmap="interactive", time_unit="s")
    plt.close("all")

    # duplicates in ``picks`` are an error
    with pytest.raises(ValueError, match="not unique"):
        evoked.plot_image(picks=[0, 0], time_unit="s")

    # With ``show_names="all"`` every picked channel name appears as a y tick.
    shown_names = evoked.ch_names[3:5]
    shown_picks = [evoked.ch_names.index(name) for name in shown_names]
    fig = evoked.plot_image(show_names="all", time_unit="s", picks=shown_picks)
    fig.canvas.draw_idle()
    tick_labels = fig.axes[0].get_yticklabels()
    assert len(tick_labels) == len(shown_names)
    for expected_name, label in zip(shown_names, tick_labels):
        assert expected_name in str(label)
    evoked.plot_image(show_names=True, time_unit="s")

    # test groupby
    evoked.plot_image(group_by=dict(sel=sel), axes=dict(sel=plt.axes()))
    plt.close("all")
    # ``group_by`` and ``axes`` must both be dicts.
    for group_by, axes in (("something", dict()), (dict(), "something")):
        with pytest.raises(ValueError):
            evoked.plot_image(group_by=group_by, axes=axes)
    with pytest.raises(ValueError, match="`clim` must be a dict."):
        evoked.plot_image(clim=[-4, 4])
def test_plot_white():
    """Exercise whitened evoked plotting with rank, axes and SSS variations."""
    noise_cov = read_cov(cov_fname)
    noise_cov["method"] = "empirical"
    noise_cov["projs"] = []  # avoid warnings
    evoked = _get_epochs().average()
    evoked.set_eeg_reference("average")  # Avoid warnings

    # test rank param.
    with pytest.raises(ValueError, match="exceeds"):
        evoked.plot_white(noise_cov, rank={"mag": 10})
    full_rank = {"mag": 1, "grad": 8, "eeg": 2}
    evoked.plot_white(noise_cov, rank=full_rank, time_unit="s")
    fig = evoked.plot_white(noise_cov, rank={"mag": 1}, time_unit="s")  # test rank
    # Reuse the first four axes of the figure that was just created.
    evoked.plot_white(noise_cov, rank={"grad": 8}, time_unit="s", axes=fig.axes[:4])
    with pytest.raises(ValueError, match=r"must have shape \(4,\), got \(2,"):
        evoked.plot_white(noise_cov, axes=fig.axes[:2])
    with pytest.raises(ValueError, match="When not using SSS"):
        evoked.plot_white(noise_cov, rank={"meg": 306})
    evoked.plot_white([noise_cov, noise_cov], time_unit="s")
    plt.close("all")

    # Two covariances: the function-level API creates a figure with nine axes.
    fig = plot_evoked_white(evoked, [noise_cov, noise_cov])
    assert len(fig.axes) == 9
    axes_grid = np.array(fig.axes[:6]).reshape(3, 2)
    plot_evoked_white(evoked, [noise_cov, noise_cov], axes=axes_grid)
    with pytest.raises(ValueError, match=r"have shape \(3, 2\), got"):
        plot_evoked_white(evoked, [noise_cov, noise_cov], axes=axes_grid[:, :1])

    # Hack to test plotting of maxfiltered data
    evoked_sss = _get_epochs(picks=("meg", "eeg")).average()
    evoked_sss.set_eeg_reference(projection=True).apply_proj()
    max_info = {"sss_info": {"in_order": 80, "components": np.arange(80)}}
    with evoked_sss.info._unlock():
        evoked_sss.info["proc_history"] = [{"max_info": max_info}]
    evoked_sss.plot_white([noise_cov, noise_cov], rank={"meg": 64})
    with pytest.raises(ValueError, match="When using SSS"):
        evoked_sss.plot_white(noise_cov, rank={"grad": 201}, verbose="error")
    evoked_sss.plot_white(noise_cov, rank={"meg": 302}, time_unit="s")
@pytest.mark.parametrize(
    "combine,vlines,title,picks",
    (
        pytest.param(None, [0.1, 0.2], "MEG 0113", "MEG 0113", id="singlepick"),
        pytest.param("mean", [], "(mean)", "mag", id="mag-mean"),
        pytest.param("gfp", "auto", "(GFP)", "eeg", id="eeg-gfp"),
        pytest.param(None, "auto", "(RMS)", ["MEG 0113", "MEG 0112"], id="meg-rms"),
        pytest.param(
            "std", "auto", "(std. dev.)", ["MEG 0113", "MEG 0112"], id="meg-std"
        ),
        pytest.param(
            lambda x: np.min(x, axis=1), "auto", "MEG 0112", [0, 1], id="intpicks"
        ),
    ),
)
def test_plot_compare_evokeds_title(evoked, picks, vlines, combine, title):
    """Check the axes title produced by ``plot_compare_evokeds``.

    Each parametrization varies ``picks``, ``combine`` and ``vlines`` and
    states the suffix that the first figure's first axes title must end with.
    """
    # a 1-channel pick also shows the sensor inset
    figs = plot_compare_evokeds(evoked, picks=picks, vlines=vlines, combine=combine)
    first_title = figs[0].axes[0].get_title()
    assert first_title.endswith(title)
@pytest.mark.slowtest  # slow on Azure
def test_plot_compare_evokeds(evoked):
    """Test plot_compare_evokeds.

    Walks through the accepted input containers, confidence-interval options,
    styling arguments, colormaps, error and warning paths, ``axes="topo"``,
    and a few degenerate inputs (flat data, a single time point).

    Parameters
    ----------
    evoked : Evoked
        Evoked instance supplied by a pytest fixture defined outside this
        module.
    """
    # test defaults
    figs = plot_compare_evokeds(evoked)
    assert len(figs) == 3
    # test passing more than one evoked
    red, blue = evoked.copy(), evoked.copy()
    # Pad the comment with 100 asterisks; presumably this feeds the
    # legend-length assertion in the "old tests" section below.
    red.comment = red.comment + "*" * 100
    red.data *= 1.5
    blue.data /= 1.5
    evoked_dict = {"aud/l": blue, "aud/r": red, "vis": evoked}
    # Eleven conditions; used below to trigger the color-cycle error.
    huge_dict = {f"cond{i}": ev for i, ev in enumerate([evoked] * 11)}
    plot_compare_evokeds(evoked_dict)  # dict
    plot_compare_evokeds([[red, evoked], [blue, evoked]])  # list of lists
    figs = plot_compare_evokeds({"cond": [blue, red, evoked]})  # dict of list
    # test that confidence bands are plausible
    for fig in figs:
        # Bounding box of the first collection on the first axes.
        extents = fig.axes[0].collections[0].get_paths()[0].get_extents()
        xlim, ylim = extents.get_points().T
        assert np.allclose(xlim, evoked.times[[0, -1]])
        line = fig.axes[0].lines[0]
        xvals = line.get_xdata()
        assert np.allclose(xvals, evoked.times)
        yvals = line.get_ydata()
        # The first line must lie strictly inside the band's vertical extent.
        assert (yvals < ylim[1]).all()
        assert (yvals > ylim[0]).all()
    # test plotting eyetracking data
    plt.close("all")  # close the previous figures as to avoid a too many figs warning
    info_tmp = mne.create_info(["pupil_left"], evoked.info["sfreq"], ["pupil"])
    evoked_et = mne.EvokedArray(np.ones_like(evoked.times).reshape(1, -1), info_tmp)
    figs = plot_compare_evokeds(evoked_et, show_sensors=False)
    assert len(figs) == 1
    # test plotting only invalid channel types
    info_tmp = mne.create_info(["ias"], evoked.info["sfreq"], ["ias"])
    ev_invalid = mne.EvokedArray(np.ones_like(evoked.times).reshape(1, -1), info_tmp)
    with pytest.raises(RuntimeError, match="No valid"):
        plot_compare_evokeds(ev_invalid, picks="all")
    plt.close("all")

    # test other CI args
    def ci_func(array):
        # Custom CI callable: scales the mean over axis 0 by 0.5 and 1.5.
        return array.mean(axis=0, keepdims=True) * np.array([[0.5], [1.5]])

    ci_types = (None, False, 0.5, ci_func)
    for _ci in ci_types:
        fig = plot_compare_evokeds({"cond": [blue, red, evoked]}, ci=_ci)[0]
        # Only the float and callable settings are checked for a drawn band.
        if _ci in ci_types[2:]:
            assert np.any(
                [isinstance(coll, PolyCollection) for coll in fig.axes[0].collections]
            )
    # make sure we can get a CI even for single conditions
    fig = plot_compare_evokeds(evoked, picks="eeg", ci=ci_func)[0]
    assert np.any(
        [isinstance(coll, PolyCollection) for coll in fig.axes[0].collections]
    )
    with pytest.raises(TypeError, match='"ci" must be None, bool, float or'):
        plot_compare_evokeds(evoked, ci="foo")
    # test sensor inset, legend location, and axis inversion & truncation
    plot_compare_evokeds(
        evoked_dict,
        invert_y=True,
        legend="upper left",
        show_sensors="center",
        truncate_xaxis=False,
        truncate_yaxis=False,
    )
    plot_compare_evokeds(evoked, ylim=dict(mag=(-50, 50)), truncate_yaxis=True)
    plt.close("all")
    # test styles
    plot_compare_evokeds(
        evoked_dict,
        colors=["b", "r", "g"],
        linestyles=[":", "-", "--"],
        split_legend=True,
    )
    style_dict = dict(aud=dict(alpha=0.3), vis=dict(linewidth=3, c="k"))
    plot_compare_evokeds(
        evoked_dict,
        styles=style_dict,
        colors={"aud/r": "r"},
        linestyles=dict(vis="dotted"),
        ci=False,
    )
    plot_compare_evokeds(evoked_dict, colors=list(range(3)))
    plt.close("all")
    # test colormap
    cmap = _get_cmap("viridis")
    plot_compare_evokeds(evoked_dict, cmap=cmap, colors=dict(aud=0.4, vis=0.9))
    plot_compare_evokeds(evoked_dict, cmap=cmap, colors=dict(aud=1, vis=2))
    plot_compare_evokeds(
        evoked_dict, cmap=("cmap title", "inferno"), linestyles=["-", ":", "--"]
    )
    plt.close("all")
    # test combine
    match = "combine must be an instance of None, callable, or str"
    with pytest.raises(TypeError, match=match):
        plot_compare_evokeds(evoked, combine=["mean", "gfp"])
    plt.close("all")
    # test warnings
    with pytest.warns(RuntimeWarning, match='in "picks"; cannot combine'):
        plot_compare_evokeds(evoked, picks=[0], combine="median")
    plt.close("all")
    # test errors
    with pytest.raises(TypeError, match='"evokeds" must be a dict, list'):
        plot_compare_evokeds("foo")
    with pytest.raises(ValueError, match=r'keys in "styles" \(.*\) must '):
        plot_compare_evokeds(evoked_dict, styles=dict(foo="foo", bar="bar"))
    with pytest.raises(ValueError, match="colors in the default color cycle"):
        plot_compare_evokeds(huge_dict, colors=None)
    with pytest.raises(TypeError, match='"cmap" is specified, then "colors"'):
        plot_compare_evokeds(
            evoked_dict,
            cmap="Reds",
            colors={"aud/l": "foo", "aud/r": "bar", "vis": "baz"},
        )
    plt.close("all")
    # Too few colors / linestyles for the three conditions.
    for kwargs in [dict(colors=[0, 1]), dict(linestyles=["-", ":"])]:
        match = r"but there are only \d* (colors|linestyles). Please specify"
        with pytest.raises(ValueError, match=match):
            plot_compare_evokeds(evoked_dict, **kwargs)
    # Wrong container type for colors / linestyles.
    for kwargs in [dict(colors="foo"), dict(linestyles="foo")]:
        match = r'"(colors|linestyles)" must be a dict, list, or None; got '
        with pytest.raises(TypeError, match=match):
            plot_compare_evokeds(evoked_dict, **kwargs)
    # Dict keys that do not correspond to any condition.
    for kwargs in [dict(colors=dict(foo="f")), dict(linestyles=dict(foo="f"))]:
        match = r'If "(colors|linestyles)" is a dict its keys \(.*\) must '
        with pytest.raises(ValueError, match=match):
            plot_compare_evokeds(evoked_dict, **kwargs)
    # Invalid location strings for the legend and the sensor inset.
    for kwargs in [dict(legend="foo"), dict(show_sensors="foo")]:
        with pytest.raises(ValueError, match="not a legal MPL loc, please"):
            plot_compare_evokeds(evoked_dict, **kwargs)
    with pytest.raises(TypeError, match="an instance of list or tuple"):
        plot_compare_evokeds(evoked_dict, vlines="foo")
    with pytest.raises(ValueError, match='"truncate_yaxis" must be bool or '):
        plot_compare_evokeds(evoked_dict, truncate_yaxis="foo")
    plt.close("all")
    # test axes='topo'
    figs = plot_compare_evokeds(evoked_dict, axes="topo", legend=True)
    for fig in figs:
        # One line per condition on the first axes of each figure.
        assert len(fig.axes[0].lines) == len(evoked_dict)
    # test with (fake) CSD data
    csd = _get_epochs(picks=np.arange(315, 320)).average()  # 5 EEG chs
    # Relabel the channels in place so that they are treated as CSD channels.
    for entry in csd.info["chs"]:
        entry["coil_type"] = FIFF.FIFFV_COIL_EEG_CSD
        entry["unit"] = FIFF.FIFF_UNIT_V_M2
    plot_compare_evokeds(csd, picks="csd", axes="topo")
    # old tests
    red.info["chs"][0]["loc"][:2] = 0  # test plotting channel at zero
    (fig,) = plot_compare_evokeds(
        [red, blue], picks=[0], ci=lambda x: [x.std(axis=0), -x.std(axis=0)]
    )
    # reasonable legend lengths
    leg_texts = [t.get_text() for t in fig.axes[0].get_legend().get_texts()]
    assert all(len(lt) < 50 for lt in leg_texts)
    plot_compare_evokeds([list(evoked_dict.values())], picks=[0], ci=_parametric_ci)
    # smoke test for tmin >= 0 (from mailing list)
    red.crop(0.01, None)
    assert len(red.times) > 2
    plot_compare_evokeds(red)
    # plot a flat channel
    red.data = np.zeros_like(red.data)
    plot_compare_evokeds(red)
    # smoke test for one time point (not useful but should not fail)
    red.crop(0.02, 0.02)
    assert len(red.times) == 1
    plot_compare_evokeds(red)
    # now that we've cropped `red` (which is also stored in `evoked_dict`),
    # the conditions no longer share the same time axis:
    with pytest.raises(ValueError, match="not contain the same time instants"):
        plot_compare_evokeds(evoked_dict)
    plt.close("all")
def test_plot_compare_evokeds_neuromag122():
    """Smoke-test ``plot_compare_evokeds`` on data relabeled as Neuromag-122.

    The first 122 gradiometers of the test evoked are given the
    ``FIFFV_COIL_NM_122`` coil type and renamed ``MEG 001`` ... ``MEG 122``.
    """
    evoked = mne.read_evokeds(evoked_fname, "Left Auditory", baseline=(None, 0))
    evoked.pick(picks="grad")
    evoked.pick(evoked.ch_names[:122])
    for ch in evoked.info["chs"]:
        ch["coil_type"] = FIFF.FIFFV_COIL_NM_122
    new_names = [f"MEG {k:03}" for k in range(1, 123)]
    # Map old names to new names position by position.
    mapping = dict(zip(evoked.ch_names, new_names))
    evoked.rename_channels(mapping)
    mne.viz.plot_compare_evokeds([evoked, evoked])
@testing.requires_testing_data
def test_plot_ctf():
    """Check joint and comparison plots for evoked data from a CTF recording.

    Also verifies that topomap axes supplied by the caller stay (almost) in
    place when ``plot_joint`` draws into them.
    """
    raw = mne.io.read_raw_ctf(ctf_fname, preload=True)
    # Keep every 20th channel of the selected types.
    picks = mne.pick_types(
        raw.info, meg=True, stim=True, eog=True, ref_meg=True, exclude="bads"
    )[::20]
    epochs = mne.Epochs(
        raw,
        np.array([[200, 0, 1]]),
        event_id=1,
        tmin=-0.1,  # start of an epoch in s.
        tmax=0.5,  # end of an epoch in s.
        proj=True,
        picks=picks,
        preload=True,
        decim=10,
        verbose="error",
    )
    evoked = epochs.average()
    evoked.plot_joint(times=[0.1])
    # test plotting with invalid ylim argument
    with pytest.raises(TypeError, match="ylim must be an instance of dict or None"):
        evoked.plot_joint(times=[0.1], ts_args=dict(ylim=(-10, 10)))
    mne.viz.plot_compare_evokeds([evoked, evoked])

    # make sure axes position is "almost" unchanged
    # when axes were passed to plot_joint by the user
    times = [0.1, 0.2, 0.3]
    fig = plt.figure()

    # create custom axes for topomaps, colorbar and the timeseries
    gs = gridspec.GridSpec(3, 7, hspace=0.5, top=0.8, figure=fig)
    topo_axes = []
    for idx in range(len(times)):
        topo_axes.append(fig.add_subplot(gs[0, idx * 2 : (idx + 1) * 2]))
    topo_axes.append(fig.add_subplot(gs[0, -1]))
    ts_axis = fig.add_subplot(gs[1:, 1:-1])

    def get_axes_midpoints(axes):
        """Return the (x, y) centers of all axes except the last one."""
        boxes = [ax.get_position() for ax in axes[:-1]]
        return np.array(
            [[box.x0 + (box.width * 0.5), box.y0 + (box.height * 0.5)] for box in boxes]
        )

    midpoints_before = get_axes_midpoints(topo_axes)
    evoked.plot_joint(
        times=times,
        ts_args={"axes": ts_axis},
        topomap_args={"axes": topo_axes},
        title=None,
    )
    midpoints_after = get_axes_midpoints(topo_axes)
    displacement = np.linalg.norm(midpoints_before - midpoints_after)
    assert (displacement < 0.1).all()