# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import sys
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
from mne import (
Annotations,
Epochs,
make_fixed_length_events,
pick_types,
read_cov,
read_events,
)
from mne.io import read_raw_fif
from mne.preprocessing import ICA, create_ecg_epochs, create_eog_epochs
from mne.utils import _record_warnings, catch_logging
from mne.viz.ica import _create_properties_layout, plot_ica_properties
from mne.viz.utils import _fake_click, _fake_keypress
# Test recordings live in <package>/io/tests/data, two levels above this file's dir.
base_dir = Path(__file__).parents[2].joinpath("io", "tests", "data")
evoked_fname = base_dir / "test-ave.fif"
raw_fname = base_dir / "test_raw.fif"
cov_fname = base_dir / "test-cov.fif"
event_name = base_dir / "test-eve.fif"
raw_ctf_fname = base_dir / "test_ctf_raw.fif"

# Epoching parameters shared by the helpers and tests in this module.
event_id = 1
tmin, tmax = -0.1, 0.2

# Skip the whole module when scikit-learn is not installed.
pytest.importorskip("sklearn")
def _get_raw(preload=False):
    """Read the FIF test recording.

    Parameters
    ----------
    preload : bool
        Forwarded to ``read_raw_fif``; the data are not loaded by default.
    """
    raw = read_raw_fif(raw_fname, preload=preload)
    return raw
def _get_events():
    """Read the events array stored alongside the test recording."""
    events = read_events(event_name)
    return events
def _get_picks(raw):
"""Get picks."""
return [0, 1, 2, 6, 7, 8, 12, 13, 14] # take a only few channels
def _get_epochs():
    """Build a small, non-preloaded Epochs object from the first 10 events."""
    raw = _get_raw()
    picks = _get_picks(raw)
    first_events = _get_events()[:10]
    # The test recording triggers a projection-related RuntimeWarning.
    with pytest.warns(RuntimeWarning, match="projection"):
        epochs = Epochs(raw, first_events, event_id, tmin, tmax, picks=picks)
    return epochs
def test_plot_ica_components():
    """Test plotting of ICA solutions."""
    res = 8
    # Low-resolution topomaps without contours/sensors keep plotting fast.
    fast_test = {"res": res, "contours": 0, "sensors": False}
    raw = _get_raw()
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=8)
    ica_picks = _get_picks(raw)
    with pytest.warns(RuntimeWarning, match="(projection)|(unstable mixing matrix)"):
        ica.fit(raw, picks=ica_picks)

    # int, single-element list, several picks, repeated picks, and "all"
    for components in [0, [0], [0, 1], [0, 1] * 2, None]:
        ica.plot_components(
            components, image_interp="cubic", colorbar=True, **fast_test
        )
    plt.close("all")

    # test interactive mode (passing 'inst' arg)
    with catch_logging() as log:
        ica.plot_components(
            [0, 1],
            image_interp="cubic",
            inst=raw,
            res=16,
            verbose="debug",
            ch_type="grad",
        )
    log = log.getvalue()
    assert "grad data" in log
    assert "extrapolation mode local to mean" in log
    fig = plt.gcf()

    # test title click
    # ----------------
    lbl = fig.axes[1].get_label()
    # the last three characters of the axes label are the component index
    ica_idx = int(lbl[-3:])
    titles = [ax.title for ax in fig.axes]
    title_pos_midpoint = (
        titles[1].get_window_extent().extents.reshape((2, 2)).mean(axis=0)
    )
    # first click adds to exclude
    _fake_click(fig, fig.axes[1], title_pos_midpoint, xform="pix")
    assert ica_idx in ica.exclude
    # clicking again removes from exclude
    _fake_click(fig, fig.axes[1], title_pos_midpoint, xform="pix")
    assert ica_idx not in ica.exclude

    # test topo click
    # ---------------
    # clicking a topomap opens a properties figure for that component
    _fake_click(fig, fig.axes[1], (0.0, 0.0), xform="data")

    c_fig = plt.gcf()
    labels = [ax.get_label() for ax in c_fig.axes]

    for label in ["topomap", "image", "erp", "spectrum", "variance"]:
        assert label in labels

    topomap_ax = c_fig.axes[labels.index("topomap")]
    title = topomap_ax.get_title()
    assert lbl.split(" ")[0] == title.split(" ")[0]
    # release the interactive figure and its properties child figure
    plt.close("all")

    # test provided axes
    _, ax = plt.subplots(1, 1)
    ica.plot_components(axes=ax, picks=0, **fast_test)
    _, ax = plt.subplots(2, 1)
    ica.plot_components(axes=ax, picks=[0, 1], **fast_test)
    _, ax = plt.subplots(2, 2)
    ica.plot_components(axes=ax, picks=[0, 1, 2, 3], **fast_test)
    _, ax = plt.subplots(3, 2)
    ica.plot_components(
        axes=ax, picks=[0, 1, 2, 3, 4, 5], nrows=2, ncols=2, **fast_test
    )
    plt.close("all")

    # an ICA without ``info`` is treated as unfitted
    ica.info = None
    with pytest.raises(RuntimeError, match="fit the ICA"):
        ica.plot_components(1, ch_type="mag")
@pytest.mark.slowtest
def test_plot_ica_properties():
    """Test plotting of ICA properties."""
    raw = _get_raw(preload=True).crop(0, 5)
    raw.add_proj([], remove_existing=True)
    with raw.info._unlock():
        raw.info["highpass"] = 1.0  # fake high-pass filtering
    events = make_fixed_length_events(raw)
    picks = _get_picks(raw)[:6]
    pick_names = [raw.ch_names[k] for k in picks]
    raw.pick(pick_names)
    reject = dict(grad=4000e-13, mag=4e-12)

    epochs = Epochs(
        raw, events[:3], event_id, tmin, tmax, baseline=(None, 0), preload=True
    )

    # max_iter=1 keeps fitting fast; convergence is irrelevant for plotting
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2, max_iter=1, random_state=0)
    with _record_warnings(), pytest.warns(RuntimeWarning, match="projection"):
        ica.fit(raw)

    # test _create_properties_layout
    fig, ax = _create_properties_layout()
    assert_equal(len(ax), 5)
    with pytest.raises(ValueError, match="specify both fig and figsize"):
        _create_properties_layout(figsize=(2, 2), fig=fig)

    topoargs = dict(topomap_args={"res": 4, "contours": 0, "sensors": False})
    with catch_logging() as log:
        ica.plot_properties(raw, picks=0, verbose="debug", **topoargs)
    log = log.getvalue()
    assert raw.ch_names[0] == "MEG 0113"
    assert "extrapolation mode local to mean" in log, log
    ica.plot_properties(epochs, picks=1, dB=False, plot_std=1.5, **topoargs)
    fig = ica.plot_properties(
        epochs,
        picks=1,
        image_args={"sigma": 1.5},
        topomap_args=dict(res=4, colorbar=True),
        psd_args={"fmax": 65.0},
        plot_std=False,
        log_scale=True,
        figsize=[4.5, 4.5],
        reject=reject,
    )[0]

    # test keypresses
    ax_labels = [ax.get_label() for ax in fig.axes]
    # test topomap change type: "t" toggles the topomap channel type
    ax = fig.axes[ax_labels.index("topomap")]
    assert ax.get_title() == "ICA001 (mag)"
    _fake_keypress(fig, "t")
    assert ax.get_title() == "ICA001 (grad)"
    _fake_keypress(fig, "t")
    assert ax.get_title() == "ICA001 (mag)"
    # test log scale: "l" toggles the spectrum x-axis between log and linear
    ax = fig.axes[ax_labels.index("spectrum")]
    assert ax.get_xscale() == "log"
    _fake_keypress(fig, "l")
    assert ax.get_xscale() == "linear"
    _fake_keypress(fig, "l")
    assert ax.get_xscale() == "log"
    plt.close("all")

    # invalid argument types
    with pytest.raises(TypeError, match="must be an instance"):
        ica.plot_properties(epochs, dB=list("abc"))
    with pytest.raises(TypeError, match="must be an instance"):
        ica.plot_properties(ica)
    with pytest.raises(TypeError, match="must be an instance"):
        ica.plot_properties([0.2])
    with pytest.raises(TypeError, match="must be an instance"):
        plot_ica_properties(epochs, epochs)
    with pytest.raises(TypeError, match="must be an instance"):
        ica.plot_properties(epochs, psd_args="not dict")
    with pytest.raises(TypeError, match="must be an instance"):
        ica.plot_properties(epochs, plot_std=[])

    # user-supplied axes: five of a 2x3 grid
    fig, ax = plt.subplots(2, 3)
    ax = ax.ravel()[:-1]
    ica.plot_properties(epochs, picks=1, axes=ax, **topoargs)
    with pytest.raises(TypeError):
        plot_ica_properties(epochs, ica, picks=[0, 1], axes=ax)
    with pytest.raises(ValueError):
        ica.plot_properties(epochs, axes="not axes")
    plt.close("all")

    # Test merging grads.
    pick_names = raw.ch_names[:15:2] + raw.ch_names[1:15:2]
    raw = _get_raw(preload=True).pick(pick_names)
    raw.crop(0, 5)
    raw.info.normalize_proj()
    ica = ICA(random_state=0, max_iter=1)
    with pytest.warns(UserWarning, match="did not converge"):
        ica.fit(raw)
    ica.plot_properties(raw)
    plt.close("all")

    # Test handling of zeros
    ica = ICA(random_state=0, max_iter=1)
    epochs.pick(pick_names)
    with _record_warnings(), pytest.warns(UserWarning, match="did not converge"):
        ica.fit(epochs)
    epochs._data[0] = 0
    # Usually UserWarning: Infinite value .* for epo
    with _record_warnings():
        ica.plot_properties(epochs, **topoargs)
    plt.close("all")

    # Test Raw with annotations
    annot = Annotations(onset=[1], duration=[1], description=["BAD"])
    raw_annot = _get_raw(preload=True).set_annotations(annot).crop(0, 8)
    raw_annot.pick(np.arange(10))
    raw_annot.del_proj()

    with _record_warnings(), pytest.warns(UserWarning, match="did not converge"):
        ica.fit(raw_annot)
    # drop bad data segments
    fig = ica.plot_properties(raw_annot, picks=[0, 1], **topoargs)
    assert_equal(len(fig), 2)
    # don't drop
    ica.plot_properties(raw_annot, reject_by_annotation=False, **topoargs)
    # the annotation figures were previously left open
    plt.close("all")
def test_plot_ica_sources(raw_orig, browser_backend, monkeypatch):
    """Test plotting of ICA panel."""
    raw = raw_orig.copy().crop(0, 1)
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick([raw.ch_names[k] for k in picks])
    ica_picks = pick_types(
        raw.info, meg=True, eeg=False, stim=False, ecg=False, eog=False, exclude="bads"
    )
    ica = ICA(n_components=2)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    if sys.platform == "darwin":  # unknown transformation bug
        monkeypatch.setenv("MNE_BROWSE_RAW_SIZE", "20,20")
    fig = ica.plot_sources(raw)
    assert browser_backend._get_n_figs() == 1
    # change which component is in ICA.exclude (click data trace to remove
    # current one; click name to add other one)
    fig._redraw()
    # the browser shows excluded components as "bad" channels
    assert_array_equal(ica.exclude, [1])
    assert fig.mne.info["bads"] == [ica._ica_names[1]]
    x = fig.mne.traces[1].get_xdata()[5]
    y = fig.mne.traces[1].get_ydata()[5]
    fig._fake_click((x, y), xform="data")  # exclude = []
    assert fig.mne.info["bads"] == []
    assert_array_equal(ica.exclude, [1])  # unchanged
    fig._click_ch_name(ch_index=0, button=1)  # exclude = [0]
    assert fig.mne.info["bads"] == [ica._ica_names[0]]
    assert_array_equal(ica.exclude, [1])
    # ica.exclude only picks up the browser's selection once it is closed
    fig._fake_keypress(fig.mne.close_key)
    fig._close_event()
    assert browser_backend._get_n_figs() == 0
    assert_array_equal(ica.exclude, [0])
    # test when picks does not include ica.exclude.
    ica.plot_sources(raw, picks=[1])
    assert browser_backend._get_n_figs() == 1
    browser_backend._close_all()

    # dtype can change int->np.int64 after load, test it explicitly
    ica.n_components_ = np.int64(ica.n_components_)

    # test clicks on y-label (need >2 secs for plot_properties() to work)
    # crop a copy so the ``raw_orig`` fixture object itself is not modified
    long_raw = raw_orig.copy().crop(0, 5)
    fig = ica.plot_sources(long_raw)
    assert browser_backend._get_n_figs() == 1
    fig._redraw()
    # right-clicking a component name opens a child (properties) figure
    fig._click_ch_name(ch_index=0, button=3)
    assert len(fig.mne.child_figs) == 1
    assert browser_backend._get_n_figs() == 2
    # close child fig directly (workaround for mpl issue #18609)
    fig._fake_keypress("escape", fig=fig.mne.child_figs[0])
    assert browser_backend._get_n_figs() == 1
    fig._fake_keypress(fig.mne.close_key)
    assert browser_backend._get_n_figs() == 0
    del long_raw

    # test with annotations and a measurement date
    orig_annot = raw.annotations
    raw.set_annotations(Annotations([0.2], [0.1], "Test"))
    fig = ica.plot_sources(raw)
    if browser_backend.name == "matplotlib":
        assert len(fig.mne.ax_main.collections) == 1
        assert len(fig.mne.ax_hscroll.collections) == 1
    else:
        assert len(fig.mne.regions) == 1
        assert_allclose(fig.mne.regions[0].getRegion(), (0.2, 0.3))

    # test with annotations and no measurement date
    orig_meas_date = raw.info["meas_date"]
    raw.set_meas_date(None)
    assert raw.first_samp != 0
    raw.set_annotations(Annotations([0.2], [0.1], "Test"))
    fig = ica.plot_sources(raw)
    if browser_backend.name == "matplotlib":
        assert len(fig.mne.ax_main.collections) == 1
        assert len(fig.mne.ax_hscroll.collections) == 1
    else:
        assert len(fig.mne.regions) == 1
        assert_allclose(fig.mne.regions[0].getRegion(), (0.2, 0.3))
    # restore the original state of ``raw``
    raw.set_meas_date(orig_meas_date)
    raw.set_annotations(orig_annot)

    # test error handling: a channel the ICA was fitted on is missing
    raw_ = raw.copy().load_data()
    raw_.drop_channels("MEG 0113")
    with pytest.raises(ValueError, match="could not be picked"):
        ica.plot_sources(inst=raw_)
    epochs_ = epochs.copy().load_data()
    epochs_.drop_channels("MEG 0113")
    with pytest.raises(ValueError, match="could not be picked"):
        ica.plot_sources(inst=epochs_)
    del raw_
    del epochs_

    # test w/ epochs and evokeds
    ica.plot_sources(epochs)
    ica.plot_sources(epochs.average())
    evoked = epochs.average()
    ica.exclude = [0]
    fig = ica.plot_sources(evoked)
    # Test a click
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], "data")
    _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], "data")
    leg = ax.get_legend()
    assert len(leg.get_texts()) == len(ica.exclude) == 1

    # test passing psd_args argument
    ica.plot_sources(epochs, psd_args=dict(fmax=50))

    # plot with bad channels excluded
    ica.exclude = [0]
    ica.plot_sources(evoked)

    # regression test for `IndexError` when passing non-consecutive picks or consecutive
    # picks not including `0` (https://github.com/mne-tools/mne-python/pull/11808)
    ica.plot_sources(evoked, picks=1)

    # pretend find_bads_eog() yielded some results
    ica.labels_ = {"eog": [0], "eog/0/crazy-channel": [0]}
    ica.plot_sources(evoked)  # now with labels

    # pass an invalid inst
    with pytest.raises(ValueError, match="must be of Raw or Epochs type"):
        ica.plot_sources("meeow")
@pytest.mark.slowtest
def test_plot_ica_overlay():
    """Check ICA.plot_overlay on the FIF recording and on CTF data."""
    # ---- FIF recording -------------------------------------------------
    raw = _get_raw(preload=True)
    with raw.info._unlock():
        raw.info["highpass"] = 1.0  # pretend the data were high-pass filtered
    picks = _get_picks(raw)
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2, random_state=0)

    # an unfitted ICA must refuse to draw an overlay
    with pytest.raises(RuntimeError, match="need to fit"):
        ica.plot_overlay(inst=raw)

    # info.normalize_proj cannot be used here because of how and when ICA
    # and Epochs pick channels from the Raw data
    with pytest.warns(RuntimeWarning, match="projection"):
        ica.fit(raw, picks=picks)

    with pytest.warns(RuntimeWarning, match="projection"):
        ecg_epochs = create_ecg_epochs(raw, picks=picks)
    ica.plot_overlay(ecg_epochs.average())

    with pytest.warns(RuntimeWarning, match="projection"):
        eog_epochs = create_eog_epochs(raw, picks=picks)
    ica.plot_overlay(eog_epochs.average(), n_pca_components=2)

    # invalid inputs: a bare data array, and an integer ``exclude``
    with pytest.raises(TypeError):
        ica.plot_overlay(raw[:2, :3][0])
    with pytest.raises(TypeError):
        ica.plot_overlay(raw, exclude=2)
    ica.plot_overlay(raw)
    plt.close("all")

    # ---- CTF smoke test ------------------------------------------------
    raw_ctf = read_raw_fif(raw_ctf_fname)
    raw_ctf.apply_gradient_compensation(3)
    with raw_ctf.info._unlock():
        raw_ctf.info["highpass"] = 1.0  # pretend the data were high-pass filtered
    ctf_picks = pick_types(raw_ctf.info, meg=True, ref_meg=False)
    ica_ctf = ICA(n_components=2)
    ica_ctf.fit(raw_ctf, picks=ctf_picks)
    with pytest.warns(RuntimeWarning, match="longer than"):
        ecg_epochs = create_ecg_epochs(raw_ctf)
    ica_ctf.plot_overlay(ecg_epochs.average())
def _get_geometry(fig):
try:
return fig.axes[0].get_subplotspec().get_geometry() # pragma: no cover
except AttributeError: # MPL < 3.4 (probably)
return fig.axes[0].get_geometry() # pragma: no cover
def test_plot_ica_scores():
    """Exercise ICA.plot_scores: labels, column layout and input validation."""
    raw = _get_raw()
    picks = _get_picks(raw)
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2)
    with pytest.warns(RuntimeWarning, match="projection"):
        ica.fit(raw, picks=picks)
    ica.plot_scores([0.3, 0.2], axhline=[0.1, -0.1], figsize=(6.4, 2.7))
    ica.plot_scores([[0.3, 0.2], [0.3, 0.2]], axhline=[0.1, -0.1])

    # labels given as plain keys of ``ica.labels_``
    ica.labels_ = {"eog": 0, "ecg": 1}
    for label in ("eog", "ecg"):
        ica.plot_scores([0.3, 0.2], axhline=[0.1, -0.1], labels=label)

    # labels matched against "<kind>/<idx>/<name>" style keys as well
    ica.labels_["eog/0/foo"] = 0
    ica.labels_["ecg/1/bar"] = 0
    for label in ("foo", "eog", "ecg"):
        ica.plot_scores([0.3, 0.2], axhline=[0.1, -0.1], labels=label)

    # number of columns: (scores, extra kwargs, expected column count)
    layouts = [
        ([[0.3, 0.2], [0.3, 0.2], [0.3, 0.2]], {}, 2),
        ([[0.3, 0.2], [0.3, 0.2]], {"n_cols": 1}, 1),
        # a single score row uses one column even though two were requested
        ([0.3, 0.2], {"n_cols": 2}, 1),
    ]
    for scores, extra_kwargs, expected_cols in layouts:
        fig = ica.plot_scores(scores, axhline=[0.1, -0.1], **extra_kwargs)
        assert expected_cols == _get_geometry(fig)[1]

    with pytest.raises(ValueError, match="Need as many"):
        ica.plot_scores([0.3, 0.2], axhline=[0.1, -0.1], labels=["one", "one-too-many"])
    with pytest.raises(ValueError, match="The length of"):
        ica.plot_scores([0.2])
def test_plot_instance_components(browser_backend):
    """Drive the ICA sources browser for both Raw and Epochs instances."""
    raw = _get_raw()
    picks = _get_picks(raw)
    ica = ICA(noise_cov=read_cov(cov_fname), n_components=2)
    with pytest.warns(RuntimeWarning, match="projection"):
        ica.fit(raw, picks=picks)
    ica.exclude = [0]

    # a sequence of browser keyboard shortcuts, some pressed twice to toggle
    keys = (
        "home",
        "home",
        "end",
        "down",
        "up",
        "right",
        "left",
        "-",
        "+",
        "=",
        "d",
        "d",
        "pageup",
        "pagedown",
        "z",
        "z",
        "s",
        "s",
        "b",
    )

    def _interact(browser):
        """Press every key, click the first trace and name, then escape."""
        for key in keys:
            browser._fake_keypress(key)
        # Test a click on the first sample of the first trace
        first_trace = browser.mne.traces[0]
        point = (first_trace.get_xdata()[0], first_trace.get_ydata()[0])
        browser._fake_click(point, xform="data")
        browser._click_ch_name(ch_index=0, button=1)
        browser._fake_keypress("escape")

    _interact(ica.plot_sources(raw, title="Components"))
    browser_backend._close_all()

    _interact(ica.plot_sources(_get_epochs(), title="Components"))