# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import itertools
from contextlib import nullcontext
from pathlib import Path
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
from mne import (
Epochs,
Projection,
add_reference_channels,
create_info,
find_events,
make_forward_solution,
make_sphere_model,
pick_channels,
pick_channels_forward,
pick_types,
read_events,
read_evokeds,
set_bipolar_reference,
set_eeg_reference,
setup_volume_source_space,
)
from mne._fiff.constants import FIFF
from mne._fiff.proj import _has_eeg_average_ref_proj
from mne._fiff.reference import _apply_reference
from mne.datasets import testing
from mne.epochs import BaseEpochs, make_fixed_length_epochs
from mne.io import RawArray, read_raw_fif
from mne.utils import _record_warnings, catch_logging
# Data directory resolved relative to this file: two levels up, then
# ``io/tests/data``.
base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
raw_fname = base_dir / "test_raw.fif"
# Files below live in the "testing" dataset.  ``download=False`` is passed, so
# presumably nothing is fetched at import time; the tests in this module that
# read these files are decorated with ``testing.requires_testing_data``.
data_dir = testing.data_path(download=False) / "MEG" / "sample"
fif_fname = data_dir / "sample_audvis_trunc_raw.fif"
eve_fname = data_dir / "sample_audvis_trunc_raw-eve.fif"
ave_fname = data_dir / "sample_audvis-ave.fif"
def _test_reference(raw, reref, ref_data, ref_from):
    """Assert that ``reref`` equals ``raw`` with a reference subtracted.

    Parameters
    ----------
    raw : instance
        The un-referenced instance; its ``_data``, ``info`` and ``ch_names``
        are read.
    reref : instance
        The re-referenced counterpart of ``raw``; its ``_data`` is read.
    ref_data : array | None
        The reference signal returned by the re-referencing call. When
        ``None``, only the "other channels are untouched" check is run.
    ref_from : list of str
        Names of the channels the reference was computed from.
    """
    info = raw.info
    eeg_idx = pick_types(info, meg=False, eeg=True, exclude="bads")
    other_idx = pick_types(
        info, meg=True, eeg=False, eog=True, stim=True, exclude="bads"
    )

    # Indices of the reference channels (resolved even when ref_data is None,
    # so an unknown channel name always raises here)
    ref_idx = [raw.ch_names.index(name) for name in ref_from]

    before, after = raw._data, reref._data
    have_ref = ref_data is not None

    # The reference signal must be the mean over the reference channels
    if have_ref:
        assert_array_equal(ref_data, before[..., ref_idx, :].mean(-2))

    # Channels that are not EEG must come through unchanged
    assert_allclose(
        before[..., other_idx, :], after[..., other_idx, :], rtol=1e-6, atol=1e-15
    )

    if not have_ref:
        return

    # Adding the reference back must recover the original EEG data. For epochs
    # a channel axis is inserted so the reference broadcasts over channels.
    if isinstance(raw, BaseEpochs):
        offset = ref_data[:, np.newaxis, :]
    else:
        offset = ref_data
    restored = after[..., eeg_idx, :] + offset
    assert_allclose(before[..., eeg_idx, :], restored, rtol=1e-6, atol=1e-15)
@testing.requires_testing_data
def test_apply_reference():
    """Exercise ``_apply_reference`` on Raw, Epochs and Evoked instances."""
    ref_names = ["EEG 001", "EEG 002"]

    def _reref_copy_and_check(inst):
        # Re-reference a copy of ``inst`` and validate it against the original.
        # A fresh list is passed on every call, as separate literals would be.
        new_inst, new_ref = _apply_reference(inst.copy(), ref_from=list(ref_names))
        assert new_inst.info["custom_ref_applied"]
        _test_reference(inst, new_inst, new_ref, list(ref_names))
        return new_inst

    raw = read_raw_fif(fif_fname, preload=True)

    # Raw: work on a copy so the original serves as the baseline
    reref = _reref_copy_and_check(raw)

    # A custom reference must have removed the average-reference projector
    assert not _has_eeg_average_ref_proj(reref.info)

    # Without an explicit copy, the very same object is handed back
    reref, ref_data = _apply_reference(raw, list(ref_names))
    assert raw is reref

    # An empty reference list must leave the data as-is
    reref, ref_data = _apply_reference(raw.copy(), [])
    assert_array_equal(raw._data, reref._data)

    # Epochs restricted to the EEG channels
    raw = read_raw_fif(fif_fname, preload=False)
    events = read_events(eve_fname)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    epochs_kwargs = dict(
        events=events,
        event_id=1,
        tmin=-0.2,
        tmax=0.5,
        picks=picks_eeg,
        preload=True,
    )
    epochs = Epochs(raw, **epochs_kwargs)
    _reref_copy_and_check(epochs)

    # Evoked derived from those epochs
    evoked = epochs.average()
    _reref_copy_and_check(evoked)

    # Data that has not been preloaded cannot be re-referenced
    raw_np = read_raw_fif(fif_fname, preload=False)
    with pytest.raises(RuntimeError):
        _apply_reference(raw_np, ["EEG 001"])

    # An inactive SSP projector covering EEG 001 / EEG 002 is attached to see
    # how it interacts with the channels involved in re-referencing
    raw = read_raw_fif(fif_fname, preload=True)
    proj_data = {
        "col_names": ["EEG 001", "EEG 002"],
        "row_names": None,
        "data": np.array([[1, 1]]),
        "ncol": 2,
        "nrow": 1,
    }
    raw.add_proj(Projection(active=False, data=proj_data, desc="test", kind=1))

    # Referencing from EEG 001 (with default targets) must be refused
    with pytest.raises(RuntimeError, match="Inactive signal space"):
        _apply_reference(raw, ["EEG 001"])

    # Channels outside the projector can be re-referenced without error
    _apply_reference(raw, ["EEG 003"], ["EEG 004"])

    # Data flagged as CSD cannot be re-referenced
    with raw.info._unlock():
        raw.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_CSD
    with pytest.raises(RuntimeError, match="Cannot set.* type 'CSD'"):
        raw.set_eeg_reference()
@testing.requires_testing_data
def test_set_eeg_reference():
    """Test rereference eeg data.

    Walks ``set_eeg_reference`` through projection-based average references,
    custom channel references, switching between the two, the empty reference
    list, and invalid argument combinations. Several steps modify ``raw`` in
    place (``copy=False``), so the order of the sections matters.
    """
    raw = read_raw_fif(fif_fname, preload=True)
    # Start from a state without any projectors
    with raw.info._unlock():
        raw.info["projs"] = []

    # Test setting an average reference projection
    assert not _has_eeg_average_ref_proj(raw.info)
    reref, ref_data = set_eeg_reference(raw, projection=True)
    assert _has_eeg_average_ref_proj(reref.info)
    # The projector is only added, not yet applied
    assert not reref.info["projs"][0]["active"]
    assert ref_data is None
    reref.apply_proj()
    eeg_chans = [raw.ch_names[ch] for ch in pick_types(raw.info, meg=False, eeg=True)]
    # ref_data is None here, so only the non-EEG channels are compared
    _test_reference(
        raw, reref, ref_data, [ch for ch in eeg_chans if ch not in raw.info["bads"]]
    )

    # Test setting an average reference when one was already present
    with pytest.warns(RuntimeWarning, match="untouched"):
        reref, ref_data = set_eeg_reference(raw, copy=False, projection=True)
    assert ref_data is None

    # Test setting an average reference on non-preloaded data
    raw_nopreload = read_raw_fif(fif_fname, preload=False)
    with raw_nopreload.info._unlock():
        raw_nopreload.info["projs"] = []
    reref, ref_data = set_eeg_reference(raw_nopreload, projection=True)
    assert _has_eeg_average_ref_proj(reref.info)
    assert not reref.info["projs"][0]["active"]

    # Rereference raw data by creating a copy of original data
    reref, ref_data = set_eeg_reference(raw, ["EEG 001", "EEG 002"], copy=True)
    assert reref.info["custom_ref_applied"]
    _test_reference(raw, reref, ref_data, ["EEG 001", "EEG 002"])

    # Test that data is modified in place when copy=False
    reref, ref_data = set_eeg_reference(raw, ["EEG 001", "EEG 002"], copy=False)
    assert raw is reref

    # Test moving from custom to average reference
    reref, ref_data = set_eeg_reference(raw, ["EEG 001", "EEG 002"])
    reref, _ = set_eeg_reference(reref, projection=True)
    assert _has_eeg_average_ref_proj(reref.info)
    assert not reref.info["custom_ref_applied"]

    # Failure path: with no EEG channels left, creating an average reference
    # must raise. Note that picking only MEG channels is asserted below to
    # reset the custom_ref_applied flag to OFF before that call is made.
    reref = raw.copy()
    with reref.info._unlock():
        reref.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_ON
    reref.pick(picks="meg")  # Cause making average ref fail
    # should have turned it off
    assert reref.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_OFF
    with pytest.raises(ValueError, match="found to rereference"):
        set_eeg_reference(reref, projection=True)

    # Test moving from average to custom reference
    reref, ref_data = set_eeg_reference(raw, projection=True)
    reref, _ = set_eeg_reference(reref, ["EEG 001", "EEG 002"])
    assert not _has_eeg_average_ref_proj(reref.info)
    assert len(reref.info["projs"]) == 0
    assert reref.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_ON

    # Test that disabling the reference does not change the data
    assert _has_eeg_average_ref_proj(raw.info)
    reref, _ = set_eeg_reference(raw, [])
    assert_array_equal(raw._data, reref._data)
    assert not _has_eeg_average_ref_proj(reref.info)

    # make sure ref_channels=[] removes average reference projectors
    # (same call as above, repeated to check only the projector state)
    assert _has_eeg_average_ref_proj(raw.info)
    reref, _ = set_eeg_reference(raw, [])
    assert not _has_eeg_average_ref_proj(reref.info)

    # Test that average reference gives identical results when calculated
    # via SSP projection (projection=True) or directly (projection=False)
    with raw.info._unlock():
        raw.info["projs"] = []
    reref_1, _ = set_eeg_reference(raw.copy(), projection=True)
    reref_1.apply_proj()
    reref_2, _ = set_eeg_reference(raw.copy(), projection=False)
    assert_allclose(reref_1._data, reref_2._data, rtol=1e-6, atol=1e-15)

    # Test average reference without projection
    reref, ref_data = set_eeg_reference(
        raw.copy(), ref_channels="average", projection=False
    )
    _test_reference(raw, reref, ref_data, eeg_chans)

    # NOTE(review): the two positional ``True`` values are presumably ``copy``
    # and ``projection`` -- confirm against the set_eeg_reference signature.
    with pytest.raises(ValueError, match='supported for ref_channels="averag'):
        set_eeg_reference(raw, [], True, True)
    with pytest.raises(ValueError, match='supported for ref_channels="averag'):
        set_eeg_reference(raw, ["EEG 001"], True, True)
@pytest.mark.parametrize(
    "ch_type, msg",
    [
        ("auto", ("ECoG",)),
        ("ecog", ("ECoG",)),
        ("dbs", ("DBS",)),
        (["ecog", "dbs"], ("ECoG", "DBS")),
    ],
)
@pytest.mark.parametrize("projection", [False, True])
def test_set_eeg_reference_ch_type(ch_type, msg, projection):
    """Check re-referencing of ECoG and DBS channels via ``ch_type``."""
    # Regression test for gh-6454; DBS handling was added in gh-8739
    names = ["ECOG01", "ECOG02", "DBS01", "DBS02", "MISC"]
    kinds = ["ecog"] * 2 + ["dbs"] * 2 + ["misc"]
    data = np.random.RandomState(0).randn(5, 1000)
    raw = RawArray(data, create_info(names, 1000.0, kinds))

    # Channels expected to form the reference: the two ECoG channels for
    # "auto", otherwise whatever picking ``ch_type`` yields
    expected_ref = (
        names[:2] if ch_type == "auto" else raw.copy().pick(picks=ch_type).ch_names
    )

    with catch_logging() as log:
        reref, ref_data = set_eeg_reference(
            raw.copy(), ch_type=ch_type, projection=projection, verbose=True
        )

    if not projection:
        logged = log.getvalue()
        assert f"Applying a custom {msg}" in logged
        assert reref.info["custom_ref_applied"]  # gh-7350
        _test_reference(raw, reref, ref_data, expected_ref)

    # Asking for EEG when the recording has none must fail; the message
    # depends on whether a projector was requested
    if projection:
        match = "no EEG data found"
    else:
        match = "No channels supplied"
    with pytest.raises(ValueError, match=match):
        set_eeg_reference(raw, ch_type="eeg", projection=projection)

    # gh-8739: "auto" with no channel type that can be re-referenced
    mag_info = create_info(5, 1000.0, ["mag"] * 4 + ["misc"])
    raw2 = RawArray(data, mag_info)
    no_chan_msg = "No EEG, ECoG, sEEG or DBS channels found to rereference."
    with pytest.raises(ValueError, match=no_chan_msg):
        set_eeg_reference(raw2, ch_type="auto", projection=projection)
@testing.requires_testing_data
def test_set_eeg_reference_rest():
    """Test setting a REST reference.

    Applies the ``"REST"`` reference to a one-second EEG-only recording using
    a forward solution built from a sphere model, checks the error paths, and
    compares the result on an evoked response against stored FieldTrip output.
    """
    raw = read_raw_fif(fif_fname).crop(0, 1).pick(picks="eeg").load_data()
    raw.info["bads"] = ["EEG 057"]  # should be excluded
    # ``same``: index of the bad channel, whose data must stay unchanged;
    # ``picks``: every other channel index
    same = [raw.ch_names.index(raw.info["bads"][0])]
    picks = np.setdiff1d(np.arange(len(raw.ch_names)), same)
    trans = None
    # Use fixed values from old sphere fit to reduce lines changed with fixed algorithm
    sphere = make_sphere_model(
        [-0.00413508, 0.01598787, 0.05175598],
        0.09100286249131773,
    )
    src = setup_volume_source_space(pos=20.0, sphere=sphere, exclude=30.0)
    assert src[0]["nuse"] == 223  # low but fast
    fwd = make_forward_solution(raw.info, trans, src, sphere)
    orig_data = raw.get_data()
    # Average-referenced copy, used as a comparison point below
    avg_data = raw.copy().set_eeg_reference("average").get_data()
    assert_array_equal(avg_data[same], orig_data[same])  # not processed
    # From here on ``raw`` itself carries the REST reference
    raw.set_eeg_reference("REST", forward=fwd)
    rest_data = raw.get_data()
    assert_array_equal(rest_data[same], orig_data[same])
    # should be more similar to an avg ref than nose ref
    orig_corr = np.corrcoef(rest_data[picks].ravel(), orig_data[picks].ravel())[0, 1]
    avg_corr = np.corrcoef(rest_data[picks].ravel(), avg_data[picks].ravel())[0, 1]
    assert -0.6 < orig_corr < -0.5
    assert 0.1 < avg_corr < 0.2
    # and applying an avg ref after should work
    avg_after = raw.set_eeg_reference("average").get_data()
    assert_allclose(avg_after, avg_data, atol=1e-12)
    # REST without a forward solution is a TypeError
    with pytest.raises(TypeError, match='forward when ref_channels="REST"'):
        raw.set_eeg_reference("REST")
    # A forward solution lacking one of the channels is rejected
    fwd_bad = pick_channels_forward(fwd, raw.ch_names[:-1])
    with pytest.raises(ValueError, match="Missing channels"):
        raw.set_eeg_reference("REST", forward=fwd_bad)
    # compare to FieldTrip
    evoked = read_evokeds(ave_fname, baseline=(None, 0))[0]
    evoked.info["bads"] = []
    evoked.pick(picks="eeg")
    assert len(evoked.ch_names) == 60
    # Data obtained from FieldTrip with something like (after evoked.save'ing
    # then scipy.io.savemat'ing fwd['sol']['data']):
    # dat = ft_read_data('ft-ave.fif');
    # load('leadfield.mat', 'G');
    # dat_ref = ft_preproc_rereference(dat, 'all', 'rest', true, G);
    # sprintf('%g ', dat_ref(:, 171));
    data_array = "-3.3265e-05 -3.2419e-05 -3.18758e-05 -3.24079e-05 -3.39801e-05 -3.40573e-05 -3.24163e-05 -3.26896e-05 -3.33814e-05 -3.54734e-05 -3.51289e-05 -3.53229e-05 -3.51532e-05 -3.53149e-05 -3.4505e-05 -3.03462e-05 -2.81848e-05 -3.08895e-05 -3.27158e-05 -3.4605e-05 -3.47728e-05 -3.2459e-05 -3.06552e-05 -2.53255e-05 -2.69671e-05 -2.83425e-05 -3.12836e-05 -3.30965e-05 -3.34099e-05 -3.32766e-05 -3.32256e-05 -3.36385e-05 -3.20796e-05 -2.7108e-05 -2.47054e-05 -2.49589e-05 -2.7382e-05 -3.09774e-05 -3.12003e-05 -3.1246e-05 -3.07572e-05 -2.64942e-05 -2.25505e-05 -2.67194e-05 -2.86e-05 -2.94903e-05 -2.96249e-05 -2.92653e-05 -2.86472e-05 -2.81016e-05 -2.69737e-05 -2.48076e-05 -3.00473e-05 -2.73404e-05 -2.60153e-05 -2.41608e-05 -2.61937e-05 -2.5539e-05 -2.47104e-05 -2.35194e-05"  # noqa: E501
    want = np.array(data_array.split(" "), float)
    norm = np.linalg.norm(want)
    # Sample nearest to t = 0.083 (the assert pins it to index 170, i.e. the
    # 1-based column 171 used in the FieldTrip snippet above)
    idx = np.argmin(np.abs(evoked.times - 0.083))
    assert idx == 170
    old = evoked.data[:, idx].ravel()
    # Before re-referencing, the data is far from the FieldTrip result
    exp_var = 1 - np.linalg.norm(want - old) / norm
    assert 0.006 < exp_var < 0.008
    evoked.set_eeg_reference("REST", forward=fwd)
    # After re-referencing: far from the old data, close to FieldTrip
    exp_var_old = 1 - np.linalg.norm(evoked.data[:, idx] - old) / norm
    assert 0.005 < exp_var_old <= 0.009
    exp_var = 1 - np.linalg.norm(evoked.data[:, idx] - want) / norm
    assert 0.995 < exp_var <= 1
@testing.requires_testing_data
@pytest.mark.parametrize("inst_type", ["raw", "epochs"])
@pytest.mark.parametrize(
    "ref_channels, expectation",
    [
        (
            {2: "EEG 001"},
            pytest.raises(
                TypeError,
                match="Keys in the ref_channels dict must be strings. "
                "Your dict has keys of type int.",
            ),
        ),
        (
            {"EEG 001": (1, 2)},
            pytest.raises(
                TypeError,
                match="Values in the ref_channels dict must be strings. "
                "Your dict has values of type int.",
            ),
        ),
        (
            {"EEG 001": [1, 2]},
            pytest.raises(
                TypeError,
                match="Values in the ref_channels dict must be strings. "
                "Your dict has values of type int.",
            ),
        ),
        (
            {"EEG 999": "EEG 001"},
            pytest.raises(
                ValueError,
                match=r"ref_channels dict contains invalid key\(s\) \(EEG 999\) "
                "that are not names of channels in the instance.",
            ),
        ),
        (
            {"EEG 001": "EEG 999"},
            pytest.raises(
                ValueError,
                match=r"ref_channels dict contains invalid value\(s\) \(EEG 999\) "
                "that are not names of channels in the instance.",
            ),
        ),
        (
            {"EEG 001": "EEG 057"},
            pytest.warns(
                RuntimeWarning,
                match=r"ref_channels dict contains value\(s\) \(EEG 057\) "
                "that are marked as bad channels.",
            ),
        ),
        (
            {"EEG 001": "STI 001"},
            pytest.warns(
                RuntimeWarning,
                match=(
                    r"Channel EEG 001 \(eeg\) is referenced to channel "
                    r"STI 001 which is a different channel type \(stim\)."
                ),
            ),
        ),
        (
            {"EEG 001": "EEG 001"},
            pytest.warns(
                RuntimeWarning,
                match=(
                    "Channel EEG 001 is self-referenced, "
                    "which will nullify the channel."
                ),
            ),
        ),
        (
            {"EEG 001": "EEG 002", "EEG 002": "EEG 003", "EEG 003": "EEG 005"},
            nullcontext(),
        ),
        (
            {
                "EEG 001": ["EEG 002", "EEG 003"],
                "EEG 002": "EEG 003",
                "EEG 003": "EEG 005",
            },
            nullcontext(),
        ),
    ],
)
def test_set_eeg_reference_dict(ref_channels, inst_type, expectation):
    """Check dict-valued ``ref_channels`` on Raw and Epochs instances.

    Each parametrized case pairs a ``ref_channels`` dict with the context
    manager (``pytest.raises``, ``pytest.warns`` or ``nullcontext``) that the
    re-referencing call is expected to satisfy. Only the ``nullcontext`` cases
    go on to verify the resulting data.
    """
    # Build an instance whose data is not loaded yet
    if inst_type == "raw":
        inst = read_raw_fif(fif_fname).crop(0, 1).pick(picks=["eeg", "stim"])
    elif inst_type == "epochs":
        inst = Epochs(
            read_raw_fif(fif_fname, preload=False),
            events=read_events(eve_fname),
            event_id=1,
            tmin=-0.2,
            tmax=0.5,
            preload=False,
        )

    # Re-referencing without loaded data must be refused
    not_loaded = "By default, MNE does not load data.*Applying a reference requires.*"
    with pytest.raises(RuntimeError, match=not_loaded):
        inst.set_eeg_reference(ref_channels=ref_channels)

    inst.load_data()
    inst.info["bads"] = ["EEG 057"]
    with expectation:
        reref, _ = set_eeg_reference(inst.copy(), ref_channels, copy=False)

    # Data checks are only performed for the cases expected to run cleanly
    if not isinstance(expectation, nullcontext):
        return

    # The custom reference flag must be switched on
    assert reref.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_ON

    before = inst._data
    after = reref._data

    # Split channel indices into those left alone and those re-referenced
    untouched = pick_channels(
        inst.ch_names,
        [name for name in inst.ch_names if name not in ref_channels],
    )
    rereffed = pick_channels(inst.ch_names, list(ref_channels), ordered=True)

    # Channels that are not keys of the dict must be unchanged
    assert_allclose(
        before[..., untouched, :], after[..., untouched, :], rtol=1e-6, atol=1e-15
    )

    # Expected reference per dict entry: mean over its source channel(s)
    per_target = []
    for sources in ref_channels.values():
        # pick_channels expects a list
        source_list = [sources] if isinstance(sources, str) else sources
        source_idx = pick_channels(inst.ch_names, source_list, ordered=True)
        per_target.append(before[..., source_idx, :].mean(-2, keepdims=True))
    if inst_type == "epochs":
        ref_data = np.concatenate(per_target, axis=1)
    else:
        ref_data = np.squeeze(np.array(per_target))

    # Adding the reference back must recover the original data
    assert_allclose(
        before[..., rereffed, :],
        after[..., rereffed, :] + ref_data,
        rtol=1e-6,
        atol=1e-15,
    )
@testing.requires_testing_data
@pytest.mark.parametrize("inst_type", ("raw", "epochs", "evoked"))
def test_set_bipolar_reference(inst_type):
    """Check ``set_bipolar_reference`` on Raw, Epochs and Evoked instances."""
    raw = read_raw_fif(fif_fname, preload=True)
    raw.apply_proj()

    # Build the instance under test from the projected raw data
    if inst_type == "raw":
        inst = raw
        del raw
    elif inst_type in ["epochs", "evoked"]:
        events = find_events(raw, stim_channel="STI 014")
        epochs = Epochs(raw, events, tmin=-0.3, tmax=0.7, preload=True)
        inst = epochs.average() if inst_type == "evoked" else epochs
        del epochs

    # An unknown key in ch_info is rejected; without it the call succeeds
    ch_info = {"kind": FIFF.FIFFV_EOG_CH, "extra": "some extra value"}
    with pytest.raises(KeyError, match="key errantly present"):
        set_bipolar_reference(inst, "EEG 001", "EEG 002", "bipolar", ch_info)
    del ch_info["extra"]
    reref = set_bipolar_reference(inst, "EEG 001", "EEG 002", "bipolar", ch_info)
    assert reref.info["custom_ref_applied"]

    # The virtual channel must equal anode minus cathode
    pair = inst.copy().pick(["EEG 001", "EEG 002"])._data
    manual = pair[..., 0, :] - pair[..., 1, :]
    virtual = reref.copy().pick(["bipolar"])._data[..., 0, :]
    assert_allclose(manual, virtual)

    # Anode and cathode are replaced by the virtual channel
    for dropped in ("EEG 001", "EEG 002"):
        assert dropped not in reref.ch_names
    assert "bipolar" in reref.ch_names

    # Channel info: coil type and kind are set explicitly, the name is new,
    # everything else is inherited from the anode
    bp_info = reref.info["chs"][reref.ch_names.index("bipolar")]
    an_info = inst.info["chs"][inst.ch_names.index("EEG 001")]
    overridden = {
        "coil_type": FIFF.FIFFV_COIL_EEG_BIPOLAR,
        "kind": FIFF.FIFFV_EOG_CH,
    }
    for key in bp_info:
        if key in overridden:
            assert bp_info[key] == overridden[key], key
        elif key != "ch_name":
            assert_equal(bp_info[key], an_info[key], err_msg=key)

    # Minimalist call: the name defaults to "<anode>-<cathode>"
    reref = set_bipolar_reference(inst, "EEG 001", "EEG 002")
    assert "EEG 001-EEG 002" in reref.ch_names

    # Minimalist call with twice the same anode
    anodes = ["EEG 001", "EEG 001", "EEG 002"]
    cathodes = ["EEG 002", "EEG 003", "EEG 003"]
    reref = set_bipolar_reference(inst, anodes, cathodes)
    assert "EEG 001-EEG 002" in reref.ch_names
    assert "EEG 001-EEG 003" in reref.ch_names

    # Several bipolar channels in a single call
    reref = set_bipolar_reference(
        inst,
        ["EEG 001", "EEG 003"],
        ["EEG 002", "EEG 004"],
        ["bipolar1", "bipolar2"],
        [{"kind": FIFF.FIFFV_EOG_CH}, {"kind": FIFF.FIFFV_EOG_CH}],
    )
    four = inst.copy().pick(["EEG 001", "EEG 002", "EEG 003", "EEG 004"])._data
    first_pair = four[..., :1, :] - four[..., 1:2, :]
    second_pair = four[..., 2:3, :] - four[..., 3:4, :]
    manual = np.concatenate([first_pair, second_pair], axis=-2)
    virtual = reref.copy().pick(["bipolar1", "bipolar2"])._data
    assert_allclose(manual, virtual)

    # A bipolar reference between non-EEG channels must not set the
    # custom_ref_applied flag
    reref = set_bipolar_reference(
        inst,
        "MEG 0111",
        "MEG 0112",
        ch_info={"kind": FIFF.FIFFV_MEG_CH},
        verbose="error",
    )
    assert not reref.info["custom_ref_applied"]
    assert "MEG 0111-MEG 0112" in reref.ch_names

    # Test a battery of invalid inputs, each expected to raise ValueError
    invalid_calls = [
        (("EEG 001", ["EEG 002", "EEG 003"], "bipolar"), {}),
        ((["EEG 001", "EEG 002"], "EEG 003", "bipolar"), {}),
        (("EEG 001", "EEG 002", ["bipolar1", "bipolar2"]), {}),
        (
            ("EEG 001", "EEG 002", "bipolar"),
            {"ch_info": [{"foo": "bar"}, {"foo": "bar"}]},
        ),
        (("EEG 001", "EEG 002"), {"ch_name": "EEG 003"}),
    ]
    for args, kwargs in invalid_calls:
        with pytest.raises(ValueError):
            set_bipolar_reference(inst, *args, **kwargs)

    # A bad anode or cathode raises with on_bad="raise" and warns with
    # on_bad="warn"
    bad_handling = (
        ("raise", pytest.raises, ValueError),
        ("warn", pytest.warns, RuntimeWarning),
    )
    for on_bad, catcher, category in bad_handling:
        for bad_name in ("EEG 001", "EEG 002"):
            inst.info["bads"] = [bad_name]
            with catcher(category):
                set_bipolar_reference(inst, "EEG 001", "EEG 002", on_bad=on_bad)
def _check_channel_names(inst, ref_names):
    """Assert that every reference channel name is present in ``inst``.

    Parameters
    ----------
    inst : instance
        Object whose ``info`` is inspected.
    ref_names : str | list of str
        Name(s) of the reference channel(s) expected to exist.
    """
    names = [ref_names] if isinstance(ref_names, str) else ref_names

    # Each requested name must resolve to an entry of ``ch_names``
    found = pick_channels(inst.info["ch_names"], names)
    assert len(found) == len(names)

    # The ``chs`` list must agree with the names as well; this call should
    # raise no exceptions
    inst.info._check_consistency()
@testing.requires_testing_data
def test_add_reference():
    """Test adding a reference.

    Covers ``add_reference_channels`` for Raw, Epochs and Evoked instances,
    with one and with two new channels, with ``copy=True`` and ``copy=False``,
    plus the invalid-input paths. ``raw`` is re-read or modified in place
    between sections, so the order of the sections matters.
    """
    raw = read_raw_fif(fif_fname, preload=True)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)

    # check if channel already exists
    pytest.raises(ValueError, add_reference_channels, raw, raw.info["ch_names"][0])

    # add reference channel to Raw
    raw_ref = add_reference_channels(raw, "Ref", copy=True)
    assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 1)
    assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :])
    _check_channel_names(raw_ref, "Ref")

    # same operation in place: result must match the copy made above
    orig_nchan = raw.info["nchan"]
    raw = add_reference_channels(raw, "Ref", copy=False)
    assert_array_equal(raw._data, raw_ref._data)
    assert_equal(raw.info["nchan"], orig_nchan + 1)
    _check_channel_names(raw, "Ref")

    # for Neuromag fif's, the reference electrode location is placed in
    # elements [3:6] of each "data" electrode location
    assert_allclose(
        raw.info["chs"][-1]["loc"][:3], raw.info["chs"][picks_eeg[0]]["loc"][3:6], 1e-6
    )

    # the new channel must contain only zeros
    ref_idx = raw.ch_names.index("Ref")
    ref_data, _ = raw[ref_idx]
    assert_array_equal(ref_data, 0)

    # add reference channel to Raw when no digitization points exist
    raw = read_raw_fif(fif_fname).crop(0, 1).load_data()
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    del raw.info["dig"]

    raw_ref = add_reference_channels(raw, "Ref", copy=True)

    assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 1)
    assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :])
    _check_channel_names(raw_ref, "Ref")

    orig_nchan = raw.info["nchan"]
    raw = add_reference_channels(raw, "Ref", copy=False)
    assert_array_equal(raw._data, raw_ref._data)
    assert_equal(raw.info["nchan"], orig_nchan + 1)
    _check_channel_names(raw, "Ref")

    # Test adding an existing channel as reference channel
    pytest.raises(ValueError, add_reference_channels, raw, raw.info["ch_names"][0])

    # add two reference channels to Raw
    raw_ref = add_reference_channels(raw, ["M1", "M2"], copy=True)
    _check_channel_names(raw_ref, ["M1", "M2"])
    assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 2)
    assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :])
    assert_array_equal(raw_ref._data[-2:, :], 0)

    raw = add_reference_channels(raw, ["M1", "M2"], copy=False)
    _check_channel_names(raw, ["M1", "M2"])
    ref_idx = raw.ch_names.index("M1")
    ref_idy = raw.ch_names.index("M2")
    ref_data, _ = raw[[ref_idx, ref_idy]]
    assert_array_equal(ref_data, 0)

    # add reference channel to epochs
    raw = read_raw_fif(fif_fname, preload=True)
    events = read_events(eve_fname)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    epochs = Epochs(
        raw,
        events=events,
        event_id=1,
        tmin=-0.2,
        tmax=0.5,
        picks=picks_eeg,
        preload=True,
    )
    # default: proj=True, after which adding a Ref channel is prohibited
    pytest.raises(RuntimeError, add_reference_channels, epochs, "Ref")

    # create epochs in delayed mode, allowing removal of CAR when re-reffing
    epochs = Epochs(
        raw,
        events=events,
        event_id=1,
        tmin=-0.2,
        tmax=0.5,
        picks=picks_eeg,
        preload=True,
        proj="delayed",
    )
    epochs_ref = add_reference_channels(epochs, "Ref", copy=True)

    # for epochs the channel axis is axis 1
    assert_equal(epochs_ref._data.shape[1], epochs._data.shape[1] + 1)
    _check_channel_names(epochs_ref, "Ref")
    ref_idx = epochs_ref.ch_names.index("Ref")
    ref_data = epochs_ref.get_data(picks=[ref_idx])[:, 0]
    assert_array_equal(ref_data, 0)
    picks_eeg = pick_types(epochs.info, meg=False, eeg=True)
    assert_array_equal(epochs.get_data(picks_eeg), epochs_ref.get_data(picks_eeg))

    # add two reference channels to epochs
    raw = read_raw_fif(fif_fname, preload=True)
    events = read_events(eve_fname)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    # create epochs in delayed mode, allowing removal of CAR when re-reffing
    epochs = Epochs(
        raw,
        events=events,
        event_id=1,
        tmin=-0.2,
        tmax=0.5,
        picks=picks_eeg,
        preload=True,
        proj="delayed",
    )
    # adding two channels at once is expected to emit a RuntimeWarning
    with pytest.warns(RuntimeWarning, match="for this channel is unknown or ambiguous"):
        epochs_ref = add_reference_channels(epochs, ["M1", "M2"], copy=True)
    assert_equal(epochs_ref._data.shape[1], epochs._data.shape[1] + 2)
    _check_channel_names(epochs_ref, ["M1", "M2"])
    ref_idx = epochs_ref.ch_names.index("M1")
    ref_idy = epochs_ref.ch_names.index("M2")
    assert_equal(epochs_ref.info["chs"][ref_idx]["ch_name"], "M1")
    assert_equal(epochs_ref.info["chs"][ref_idy]["ch_name"], "M2")
    ref_data = epochs_ref.get_data([ref_idx, ref_idy])
    assert_array_equal(ref_data, 0)
    picks_eeg = pick_types(epochs.info, meg=False, eeg=True)
    assert_array_equal(epochs.get_data(picks_eeg), epochs_ref.get_data(picks_eeg))

    # add reference channel to evoked
    raw = read_raw_fif(fif_fname, preload=True)
    events = read_events(eve_fname)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    # create epochs in delayed mode, allowing removal of CAR when re-reffing
    epochs = Epochs(
        raw,
        events=events,
        event_id=1,
        tmin=-0.2,
        tmax=0.5,
        picks=picks_eeg,
        preload=True,
        proj="delayed",
    )
    evoked = epochs.average()
    evoked_ref = add_reference_channels(evoked, "Ref", copy=True)
    assert_equal(evoked_ref.data.shape[0], evoked.data.shape[0] + 1)
    _check_channel_names(evoked_ref, "Ref")
    ref_idx = evoked_ref.ch_names.index("Ref")
    ref_data = evoked_ref.data[ref_idx, :]
    assert_array_equal(ref_data, 0)
    picks_eeg = pick_types(evoked.info, meg=False, eeg=True)
    assert_array_equal(evoked.data[picks_eeg, :], evoked_ref.data[picks_eeg, :])

    # add two reference channels to evoked
    raw = read_raw_fif(fif_fname, preload=True)
    events = read_events(eve_fname)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    # create epochs in delayed mode, allowing removal of CAR when re-reffing
    epochs = Epochs(
        raw,
        events=events,
        event_id=1,
        tmin=-0.2,
        tmax=0.5,
        picks=picks_eeg,
        preload=True,
        proj="delayed",
    )
    evoked = epochs.average()
    with pytest.warns(RuntimeWarning, match="for this channel is unknown or ambiguous"):
        evoked_ref = add_reference_channels(evoked, ["M1", "M2"], copy=True)
    assert_equal(evoked_ref.data.shape[0], evoked.data.shape[0] + 2)
    _check_channel_names(evoked_ref, ["M1", "M2"])
    ref_idx = evoked_ref.ch_names.index("M1")
    ref_idy = evoked_ref.ch_names.index("M2")
    ref_data = evoked_ref.data[[ref_idx, ref_idy], :]
    assert_array_equal(ref_data, 0)
    picks_eeg = pick_types(evoked.info, meg=False, eeg=True)
    assert_array_equal(evoked.data[picks_eeg, :], evoked_ref.data[picks_eeg, :])

    # Test invalid inputs
    raw = read_raw_fif(fif_fname, preload=False)
    # data must be loaded first
    with pytest.raises(RuntimeError, match="loaded"):
        add_reference_channels(raw, ["Ref"])
    raw.load_data()
    # the name must not collide with an existing channel
    with pytest.raises(ValueError, match="Channel.*already.*"):
        add_reference_channels(raw, raw.ch_names[:1])
    # the channel argument must have a supported type
    with pytest.raises(TypeError, match="instance of"):
        add_reference_channels(raw, 1)

    # gh-10878
    # The EEG data after adding a reference channel must agree between Raw,
    # the first fixed-length epoch, and the Evoked average of those epochs
    raw = read_raw_fif(raw_fname).crop(0, 1, include_tmax=False).load_data()
    data = raw.copy().add_reference_channels(["REF"]).pick(picks="eeg")
    data = data.get_data()
    epochs = make_fixed_length_epochs(raw).load_data()
    data_2 = epochs.copy().add_reference_channels(["REF"]).pick(picks="eeg")
    data_2 = data_2.get_data(copy=False)[0]
    assert_allclose(data, data_2)
    evoked = epochs.average()
    data_3 = evoked.copy().add_reference_channels(["REF"]).pick(picks="eeg")
    data_3 = data_3.get_data()
    assert_allclose(data, data_3)
@pytest.mark.parametrize("n_ref", (1, 2))
def test_add_reorder(n_ref):
    """Check that channels can be reordered after adding reference channels."""
    # Regression test for gh-8300
    raw = read_raw_fif(raw_fname).crop(0, 0.1).del_proj().pick("eeg")
    assert len(raw.ch_names) == 60

    # New channel names: "EEG 061", ... for any extra ones, then "EEG 000" last
    extra = [f"EEG {60 + ii:03}" for ii in range(1, n_ref)]
    chs = extra + ["EEG 000"]

    # Adding channels requires preloaded data; any warnings emitted on the way
    # are recorded and ignored
    with pytest.raises(RuntimeError, match="preload"), _record_warnings():
        add_reference_channels(raw, chs, copy=False)
    raw.load_data()

    # A single new channel is added silently, two trigger a warning
    if n_ref == 1:
        ctx = nullcontext()
    else:
        assert n_ref == 2
        ctx = pytest.warns(RuntimeWarning, match="this channel is unknown or ambiguous")
    with ctx:
        add_reference_channels(raw, chs, copy=False)

    data = raw.get_data()
    assert_array_equal(data[-1], 0.0)
    assert raw.ch_names[-n_ref:] == chs

    # Move the last channel ("EEG 000") to the front: names become sequential
    new_order = raw.ch_names[-1:] + raw.ch_names[:-1]
    raw.reorder_channels(new_order)
    assert raw.ch_names == [f"EEG {ii:03}" for ii in range(60 + n_ref)]

    # Rotating the rows back by one must reproduce the data from before
    data_new = np.roll(raw.get_data(), -1, axis=0)
    assert_allclose(data, data_new)
def test_bipolar_combinations():
    """Test bipolar channel generation.

    Builds a ten-channel EEG ``RawArray`` from seeded random data and checks
    ``set_bipolar_reference`` for a single pair, for all channel pairs (with
    and without dropping the source channels), for a channel used as both
    anode and cathode, and for propagation of the "bad" status.
    """
    ch_names = [f"CH{ni + 1}" for ni in range(10)]
    info = create_info(
        ch_names=ch_names, sfreq=1000.0, ch_types=["eeg"] * len(ch_names)
    )
    # Seeded generator so that any failure is reproducible
    rng = np.random.RandomState(0)
    raw_data = rng.randn(len(ch_names), 1000)
    raw = RawArray(raw_data, info)

    def _check_bipolar(raw_test, ch_a, ch_b):
        # The virtual channel "<ch_a>-<ch_b>" must equal the manual difference
        picks = [raw_test.ch_names.index(f"{ch_a}-{ch_b}")]
        get_data_res = raw_test.get_data(picks=picks)[0, :]
        manual_a = raw_data[ch_names.index(ch_a), :]
        manual_b = raw_data[ch_names.index(ch_b), :]
        assert_array_equal(get_data_res, manual_a - manual_b)

    # test classic EOG/ECG bipolar reference (only two channels per pair).
    raw_test = set_bipolar_reference(raw, ["CH2"], ["CH1"], copy=True)
    _check_bipolar(raw_test, "CH2", "CH1")

    # test all combinations.
    a_channels, b_channels = zip(*itertools.combinations(ch_names, 2))
    a_channels, b_channels = list(a_channels), list(b_channels)
    raw_test = set_bipolar_reference(raw, a_channels, b_channels, copy=True)
    for ch_a, ch_b in zip(a_channels, b_channels):
        _check_bipolar(raw_test, ch_a, ch_b)
    # check if reference channels have been dropped.
    assert len(raw_test.ch_names) == len(a_channels)

    raw_test = set_bipolar_reference(
        raw, a_channels, b_channels, drop_refs=False, copy=True
    )
    # check if reference channels have been kept correctly.
    assert len(raw_test.ch_names) == len(a_channels) + len(ch_names)
    for idx, ch_label in enumerate(ch_names):
        manual_ch = raw_data[np.newaxis, idx]
        assert_array_equal(raw_test.get_data(ch_label), manual_ch)

    # test bipolars with a channel in both list (anode & cathode).
    raw_test = set_bipolar_reference(raw, ["CH2", "CH1"], ["CH1", "CH2"], copy=True)
    _check_bipolar(raw_test, "CH2", "CH1")
    _check_bipolar(raw_test, "CH1", "CH2")

    # test if bipolar channel is bad if anode is a bad channel
    raw.info["bads"] = ["CH1"]
    raw_test = set_bipolar_reference(
        raw, ["CH1"], ["CH2"], on_bad="ignore", ch_name="bad_bipolar", copy=True
    )
    assert raw_test.info["bads"] == ["bad_bipolar"]

    # test if bipolar channel is bad if cathode is a bad channel
    raw.info["bads"] = ["CH2"]
    raw_test = set_bipolar_reference(
        raw, ["CH1"], ["CH2"], on_bad="ignore", ch_name="bad_bipolar", copy=True
    )
    assert raw_test.info["bads"] == ["bad_bipolar"]