# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import datetime
import itertools
import re
from pathlib import Path
import numpy as np
import pytest
from mne import create_info
from mne.annotations import Annotations
from mne.datasets import testing
from mne.io import RawArray, read_raw_fif
from mne.preprocessing import annotate_amplitude
date = datetime.datetime(2021, 12, 10, 7, 52, 24, 405305, tzinfo=datetime.timezone.utc)
data_path = Path(testing.data_path(download=False))
skip_fname = data_path / "misc" / "intervalrecording_raw.fif"
@pytest.mark.parametrize("meas_date", (None, date))
@pytest.mark.parametrize("first_samp", (0, 10000))
def test_annotate_amplitude(meas_date, first_samp):
"""Test automatic annotation for segments based on peak-to-peak value."""
n_ch, n_times = 11, 1000
data = np.random.RandomState(0).randn(n_ch, n_times)
assert not (np.diff(data, axis=-1) == 0).any() # nothing flat at first
info = create_info(n_ch, 1000.0, "eeg")
# from annotate_flat: test first_samp != for gh-6295
raw = RawArray(data, info, first_samp=first_samp)
raw.info["bads"] = [raw.ch_names[-1]]
raw.set_meas_date(meas_date)
# -- test bad channels spatial marking --
for perc, dur in itertools.product((5, 99.9, 100.0), (0.005, 0.95, 0.99)):
kwargs = dict(bad_percent=perc, min_duration=dur)
# test entire channel flat
raw_ = raw.copy()
raw_._data[0] = 0.0
annots, bads = annotate_amplitude(raw_, peak=None, flat=0.0, **kwargs)
assert len(annots) == 0
assert bads == ["0"]
# test multiple channels flat
raw_ = raw.copy()
raw_._data[0] = 0.0
raw_._data[2] = 0.0
annots, bads = annotate_amplitude(raw_, peak=None, flat=0.0, **kwargs)
assert len(annots) == 0
assert bads == ["0", "2"]
# test entire channel drifting
raw_ = raw.copy()
raw_._data[0] = np.arange(0, raw.times.size * 10, 10)
annots, bads = annotate_amplitude(raw_, peak=5, flat=None, **kwargs)
assert len(annots) == 0
assert bads == ["0"]
# test multiple channels drifting
raw_ = raw.copy()
raw_._data[0] = np.arange(0, raw.times.size * 10, 10)
raw_._data[2] = np.arange(0, raw.times.size * 10, 10)
annots, bads = annotate_amplitude(raw_, peak=5, flat=None, **kwargs)
assert len(annots) == 0
assert bads == ["0", "2"]
# -- test bad channels temporal marking --
# flat channel for the 20% last points
n_good_times = int(round(0.8 * n_times))
raw_ = raw.copy()
raw_._data[0, n_good_times:] = 0.0
for perc in (5, 20):
annots, bads = annotate_amplitude(raw_, peak=None, flat=0.0, bad_percent=perc)
assert len(annots) == 0
assert bads == ["0"]
annots, bads = annotate_amplitude(raw_, peak=None, flat=0.0, bad_percent=20.1)
assert len(annots) == 1
assert len(bads) == 0
# check annotation instance
assert annots[0]["description"] == "BAD_flat"
_check_annotation(raw_, annots[0], meas_date, first_samp, n_good_times, -1)
# test multiple channels flat and multiple channels drift
raw_ = raw.copy()
raw_._data[0, 800:] = 0.0
raw_._data[1, 850:950] = 0.0
raw_._data[2, :200] = np.arange(0, 200 * 10, 10)
raw_._data[2, 200:] += raw_._data[2, 199] # add offset for next samples
raw_._data[3, 50:150] = np.arange(0, 100 * 10, 10)
raw_._data[3, 150:] += raw_._data[3, 149] # add offset for next samples
for perc in (5, 10):
annots, bads = annotate_amplitude(raw_, peak=5, flat=0.0, bad_percent=perc)
assert len(annots) == 0
assert bads == ["0", "1", "2", "3"]
for perc in (10.1, 20):
annots, bads = annotate_amplitude(raw_, peak=5, flat=0.0, bad_percent=perc)
assert len(annots) == 2
assert bads == ["0", "2"]
# check annotation instance
assert all(annot["description"] in ("BAD_flat", "BAD_peak") for annot in annots)
for annot in annots:
start_idx = 50 if annot["description"] == "BAD_peak" else 850
stop_idx = 149 if annot["description"] == "BAD_peak" else 949
_check_annotation(raw_, annot, meas_date, first_samp, start_idx, stop_idx)
annots, bads = annotate_amplitude(raw_, peak=5, flat=0.0, bad_percent=20.1)
assert len(annots) == 2
assert len(bads) == 0
# check annotation instance
assert all(annot["description"] in ("BAD_flat", "BAD_peak") for annot in annots)
for annot in annots:
start_idx = 0 if annot["description"] == "BAD_peak" else 800
stop_idx = 199 if annot["description"] == "BAD_peak" else -1
_check_annotation(raw_, annot, meas_date, first_samp, start_idx, stop_idx)
# test flat on already marked bad channel
raw_ = raw.copy()
raw_._data[-1, :] = 0.0 # this channel is already in info['bads']
annots, bads = annotate_amplitude(raw_, peak=None, flat=0.0, bad_percent=5)
assert len(annots) == 0
assert len(bads) == 0
# test drift on already marked bad channel
raw_ = raw.copy()
raw_._data[-1, :] = np.arange(0, raw.times.size * 10, 10)
annots, bads = annotate_amplitude(raw_, peak=5, flat=None, bad_percent=5)
assert len(annots) == 0
assert len(bads) == 0
@pytest.mark.parametrize("meas_date", (None, date))
@pytest.mark.parametrize("first_samp", (0, 10000))
def test_annotate_amplitude_with_overlap(meas_date, first_samp):
"""Test cases with overlap between annotations."""
n_ch, n_times = 11, 1000
data = np.random.RandomState(0).randn(n_ch, n_times)
assert not (np.diff(data, axis=-1) == 0).any() # nothing flat at first
info = create_info(n_ch, 1000.0, "eeg")
# from annotate_flat: test first_samp != for gh-6295
raw = RawArray(data, info, first_samp=first_samp)
raw.info["bads"] = [raw.ch_names[-1]]
raw.set_meas_date(meas_date)
# -- overlap between peak and flat --
raw_ = raw.copy()
raw_._data[0, 800:] = 0.0
raw_._data[1, 700:900] = np.arange(0, 200 * 10, 10)
raw_._data[1, 900:] += raw_._data[1, 899] # add offset for next samples
annots, bads = annotate_amplitude(raw_, peak=5, flat=0, bad_percent=25)
assert len(annots) == 2
assert len(bads) == 0
# check annotation instance
assert all(annot["description"] in ("BAD_flat", "BAD_peak") for annot in annots)
for annot in annots:
start_idx = 700 if annot["description"] == "BAD_peak" else 800
stop_idx = 899 if annot["description"] == "BAD_peak" else -1
_check_annotation(raw_, annot, meas_date, first_samp, start_idx, stop_idx)
# -- overlap between peak and peak on same channel --
raw_ = raw.copy()
raw_._data[0, 700:900] = np.arange(0, 200 * 10, 10)
raw_._data[0, 800:] = np.arange(1000, 300 * 10, 10)
annots, bads = annotate_amplitude(raw_, peak=5, flat=None, bad_percent=50)
assert len(annots) == 1
assert len(bads) == 0
# check annotation instance
assert annots[0]["description"] == "BAD_peak"
_check_annotation(raw_, annots[0], meas_date, first_samp, 700, -1)
# -- overlap between flat and flat on different channel --
raw_ = raw.copy()
raw_._data[0, 700:900] = 0.0
raw_._data[1, 800:] = 0.0
annots, bads = annotate_amplitude(raw_, peak=None, flat=0.01, bad_percent=50)
assert len(annots) == 1
assert len(bads) == 0
# check annotation instance
assert annots[0]["description"] == "BAD_flat"
_check_annotation(raw_, annots[0], meas_date, first_samp, 700, -1)
@pytest.mark.parametrize("meas_date", (None, date))
@pytest.mark.parametrize("first_samp", (0, 10000))
def test_annotate_amplitude_multiple_ch_types(meas_date, first_samp):
"""Test cases with several channel types."""
n_ch, n_times = 11, 1000
data = np.random.RandomState(0).randn(n_ch, n_times)
assert not (np.diff(data, axis=-1) == 0).any() # nothing flat at first
info = create_info(
n_ch, 1000.0, ["eeg"] * 3 + ["mag"] * 2 + ["grad"] * 4 + ["eeg"] * 2
)
# from annotate_flat: test first_samp != for gh-6295
raw = RawArray(data, info, first_samp=first_samp)
raw.info["bads"] = [raw.ch_names[-1]]
raw.set_meas_date(meas_date)
# -- 2 channel types both to annotate --
raw_ = raw.copy()
raw_._data[1, 800:] = 0.0
raw_._data[5, :200] = np.arange(0, 200 * 10, 10)
raw_._data[5, 200:] += raw_._data[5, 199] # add offset for next samples
annots, bads = annotate_amplitude(raw_, peak=5, flat=0, bad_percent=50)
assert len(annots) == 2
assert len(bads) == 0
# check annotation instance
assert all(annot["description"] in ("BAD_flat", "BAD_peak") for annot in annots)
for annot in annots:
start_idx = 0 if annot["description"] == "BAD_peak" else 800
stop_idx = 199 if annot["description"] == "BAD_peak" else -1
_check_annotation(raw_, annot, meas_date, first_samp, start_idx, stop_idx)
# -- 2 channel types, one flat picked, one not picked --
raw_ = raw.copy()
raw_._data[1, 800:] = 0.0
raw_._data[5, :200] = np.arange(0, 200 * 10, 10)
raw_._data[5, 200:] += raw_._data[5, 199] # add offset for next samples
annots, bads = annotate_amplitude(raw_, peak=5, flat=0, bad_percent=50, picks="eeg")
assert len(annots) == 1
assert len(bads) == 0
# check annotation instance
_check_annotation(raw_, annots[0], meas_date, first_samp, 800, -1)
assert annots[0]["description"] == "BAD_flat"
# -- 2 channel types, one flat, one not picked, reverse --
raw_ = raw.copy()
raw_._data[1, 800:] = 0.0
raw_._data[5, :200] = np.arange(0, 200 * 10, 10)
raw_._data[5, 200:] += raw_._data[5, 199] # add offset for next samples
annots, bads = annotate_amplitude(
raw_, peak=5, flat=0, bad_percent=50, picks="grad"
)
assert len(annots) == 1
assert len(bads) == 0
# check annotation instance
_check_annotation(raw_, annots[0], meas_date, first_samp, 0, 199)
assert annots[0]["description"] == "BAD_peak"
@testing.requires_testing_data
def test_flat_bad_acq_skip():
"""Test that acquisition skips are handled properly."""
# -- file with a couple of skip and flat channels --
raw = read_raw_fif(skip_fname, preload=True)
annots, bads = annotate_amplitude(raw, flat=0)
assert len(annots) == 0
assert bads == [
f"MEG{num.zfill(4)}"
for num in "141 331 421 431 611 641 1011 1021 1031 1241 1421 "
"1741 1841 2011 2131 2141 2241 2531 2541 2611 2621".split()
] # MaxFilter finds the same 21 channels
# -- overlap of flat segment with bad_acq_skip --
n_ch, n_times = 11, 1000
data = np.random.RandomState(0).randn(n_ch, n_times)
assert not (np.diff(data, axis=-1) == 0).any() # nothing flat at first
info = create_info(n_ch, 1000.0, "eeg")
raw = RawArray(data, info, first_samp=0)
raw.info["bads"] = [raw.ch_names[-1]]
bad_acq_skip = Annotations([0.5], [0.2], ["bad_acq_skip"], orig_time=None)
raw.set_annotations(bad_acq_skip)
# add flat channel overlapping with the left edge of bad_acq_skip
raw_ = raw.copy()
raw_._data[0, 400:600] = 0.0
annots, bads = annotate_amplitude(raw_, peak=None, flat=0, bad_percent=25)
assert len(annots) == 1
assert len(bads) == 0
# check annotation instance
assert annots[0]["description"] == "BAD_flat"
_check_annotation(raw_, annots[0], None, 0, 400, 499)
# add flat channel overlapping with the right edge of bad_acq_skip
raw_ = raw.copy()
raw_._data[0, 600:800] = 0.0
annots, bads = annotate_amplitude(raw_, peak=None, flat=0, bad_percent=25)
assert len(annots) == 1
assert len(bads) == 0
# check annotation instance
assert annots[0]["description"] == "BAD_flat"
_check_annotation(raw_, annots[0], None, 0, 700, 799)
# add flat channel overlapping entirely with bad_acq_skip
raw_ = raw.copy()
raw_._data[0, 200:800] = 0.0
annots, bads = annotate_amplitude(raw_, peak=None, flat=0, bad_percent=41)
assert len(annots) == 2
assert len(bads) == 0
# check annotation instance
annots = sorted(annots, key=lambda x: x["onset"])
assert all(annot["description"] == "BAD_flat" for annot in annots)
_check_annotation(raw_, annots[0], None, 0, 200, 500)
_check_annotation(raw_, annots[1], None, 0, 700, 799)
def _check_annotation(raw, annot, meas_date, first_samp, start_idx, stop_idx):
"""Util function to check an annotation."""
assert meas_date == annot["orig_time"]
if meas_date is None:
assert np.isclose(raw.times[start_idx], annot["onset"], atol=1e-4)
assert np.isclose(
raw.times[stop_idx], annot["onset"] + annot["duration"], atol=1e-4
)
else:
first_time = first_samp / raw.info["sfreq"] # because of meas_date
assert np.isclose(raw.times[start_idx], annot["onset"] - first_time, atol=1e-4)
assert np.isclose(
raw.times[stop_idx],
annot["onset"] + annot["duration"] - first_time,
atol=1e-4,
)
def test_invalid_arguments():
"""Test error messages raised by invalid arguments."""
n_ch, n_times = 2, 100
data = np.random.RandomState(0).randn(n_ch, n_times)
info = create_info(n_ch, 100.0, "eeg")
raw = RawArray(data, info, first_samp=0)
# negative floats PTP
with pytest.raises(
ValueError,
match="Argument 'flat' should define a positive threshold. Provided: '-1'.",
):
annotate_amplitude(raw, peak=None, flat=-1)
with pytest.raises(
ValueError,
match="Argument 'peak' should define a positive threshold. Provided: '-1'.",
):
annotate_amplitude(raw, peak=-1, flat=None)
# negative PTP threshold for one channel type
with pytest.raises(
ValueError,
match="Argument 'flat' should define positive "
"thresholds. Provided for channel type "
"'eog': '-1'.",
):
annotate_amplitude(raw, peak=None, flat=dict(eeg=1, eog=-1))
with pytest.raises(
ValueError,
match="Argument 'peak' should define positive "
"thresholds. Provided for channel type "
"'eog': '-1'.",
):
annotate_amplitude(raw, peak=dict(eeg=1, eog=-1), flat=None)
# test both PTP set to None
with pytest.raises(
ValueError,
match="At least one of the arguments 'peak' or 'flat' must not be None.",
):
annotate_amplitude(raw, peak=None, flat=None)
# bad_percent outside [0, 100]
with pytest.raises(
ValueError,
match="Argument 'bad_percent' should define a "
"percentage between 0% and 100%. Provided: "
"-1.0%.",
):
annotate_amplitude(raw, peak=dict(eeg=1), flat=None, bad_percent=-1)
# min_duration negative
with pytest.raises(
ValueError,
match="Argument 'min_duration' should define a "
"positive duration in seconds. Provided: "
"'-1.0' seconds.",
):
annotate_amplitude(raw, peak=dict(eeg=1), flat=None, min_duration=-1)
# min_duration equal to the raw duration
with pytest.raises(
ValueError,
match=re.escape(
"Argument 'min_duration' should define a "
"positive duration in seconds shorter than the "
"raw duration (1.0 seconds). Provided: "
"'1.0' seconds."
),
):
annotate_amplitude(raw, peak=dict(eeg=1), flat=None, min_duration=1.0)
# min_duration longer than the raw duration
with pytest.raises(
ValueError,
match=re.escape(
"Argument 'min_duration' should define a "
"positive duration in seconds shorter than the "
"raw duration (1.0 seconds). Provided: "
"'10.0' seconds."
),
):
annotate_amplitude(raw, peak=dict(eeg=1), flat=None, min_duration=10)