[074d3d]: / mne / preprocessing / tests / test_annotate_amplitude.py

Download this file

398 lines (356 with data), 16.0 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import datetime
import itertools
import re
from pathlib import Path
import numpy as np
import pytest
from mne import create_info
from mne.annotations import Annotations
from mne.datasets import testing
from mne.io import RawArray, read_raw_fif
from mne.preprocessing import annotate_amplitude
# Timezone-aware (UTC) measurement date, used as one of the ``meas_date``
# parametrizations of the tests below.
date = datetime.datetime(
    year=2021,
    month=12,
    day=10,
    hour=7,
    minute=52,
    second=24,
    microsecond=405305,
    tzinfo=datetime.timezone.utc,
)
# Root of the MNE testing dataset, queried with download=False.
data_path = Path(testing.data_path(download=False))
# FIF recording with acquisition skips and flat channels (see
# test_flat_bad_acq_skip).
skip_fname = data_path.joinpath("misc", "intervalrecording_raw.fif")
@pytest.mark.parametrize("meas_date", (None, date))
@pytest.mark.parametrize("first_samp", (0, 10000))
def test_annotate_amplitude(meas_date, first_samp):
    """Test automatic annotation of bad segments from peak-to-peak amplitude.

    Flat or drifting channels are either reported in ``bads`` or annotated as
    ``BAD_flat`` / ``BAD_peak`` segments, depending on the fraction of the
    recording they affect compared with ``bad_percent``.
    """
    n_ch, n_times = 11, 1000
    data = np.random.RandomState(0).randn(n_ch, n_times)
    # sanity check: no two consecutive samples are equal in the random data
    assert not (np.diff(data, axis=-1) == 0).any()
    info = create_info(n_ch, 1000.0, "eeg")
    # a non-zero first_samp is also exercised (from annotate_flat, gh-6295)
    raw = RawArray(data, info, first_samp=first_samp)
    raw.info["bads"] = [raw.ch_names[-1]]
    raw.set_meas_date(meas_date)
    # linear ramp over the whole recording, used to simulate a drifting channel
    ramp = np.arange(0, raw.times.size * 10, 10)

    # -- spatial marking: whole channels reported in `bads` --
    for bad_percent, min_duration in itertools.product(
        (5, 99.9, 100.0), (0.005, 0.95, 0.99)
    ):
        options = dict(bad_percent=bad_percent, min_duration=min_duration)
        # one, then two, channels flat over the entire recording
        for picked in ([0], [0, 2]):
            raw_ = raw.copy()
            raw_._data[picked] = 0.0
            annots, bads = annotate_amplitude(raw_, peak=None, flat=0.0, **options)
            assert len(annots) == 0
            assert bads == [str(idx) for idx in picked]
        # one, then two, channels drifting over the entire recording
        for picked in ([0], [0, 2]):
            raw_ = raw.copy()
            raw_._data[picked] = ramp
            annots, bads = annotate_amplitude(raw_, peak=5, flat=None, **options)
            assert len(annots) == 0
            assert bads == [str(idx) for idx in picked]

    # -- temporal marking: channel 0 flat over its last 20% --
    n_good_times = int(round(0.8 * n_times))
    raw_ = raw.copy()
    raw_._data[0, n_good_times:] = 0.0
    # bad_percent not above the flat fraction: the channel is reported as bad
    for bad_percent in (5, 20):
        annots, bads = annotate_amplitude(
            raw_, peak=None, flat=0.0, bad_percent=bad_percent
        )
        assert len(annots) == 0
        assert bads == ["0"]
    # bad_percent above the flat fraction: the flat segment is annotated
    annots, bads = annotate_amplitude(raw_, peak=None, flat=0.0, bad_percent=20.1)
    assert len(annots) == 1
    assert len(bads) == 0
    assert annots[0]["description"] == "BAD_flat"
    _check_annotation(raw_, annots[0], meas_date, first_samp, n_good_times, -1)

    # -- partial flat segments (channels 0, 1) and drifts (channels 2, 3) --
    raw_ = raw.copy()
    raw_._data[0, 800:] = 0.0  # 20% of the samples
    raw_._data[1, 850:950] = 0.0  # 10% of the samples
    raw_._data[2, :200] = np.arange(0, 200 * 10, 10)  # 20% of the samples
    raw_._data[2, 200:] += raw_._data[2, 199]  # offset the following samples
    raw_._data[3, 50:150] = np.arange(0, 100 * 10, 10)  # 10% of the samples
    raw_._data[3, 150:] += raw_._data[3, 149]  # offset the following samples
    # bad_percent of 5 or 10: all four channels are reported as bad
    for bad_percent in (5, 10):
        annots, bads = annotate_amplitude(
            raw_, peak=5, flat=0.0, bad_percent=bad_percent
        )
        assert len(annots) == 0
        assert bads == ["0", "1", "2", "3"]
    # bad_percent of 10.1 or 20: channels 0/2 are bad, channels 1/3 annotated
    spans = {"BAD_peak": (50, 149), "BAD_flat": (850, 949)}
    for bad_percent in (10.1, 20):
        annots, bads = annotate_amplitude(
            raw_, peak=5, flat=0.0, bad_percent=bad_percent
        )
        assert len(annots) == 2
        assert bads == ["0", "2"]
        assert all(annot["description"] in spans for annot in annots)
        for annot in annots:
            start_idx, stop_idx = spans[annot["description"]]
            _check_annotation(
                raw_, annot, meas_date, first_samp, start_idx, stop_idx
            )
    # bad_percent of 20.1: no bad channel, two annotations cover all segments
    spans = {"BAD_peak": (0, 199), "BAD_flat": (800, -1)}
    annots, bads = annotate_amplitude(raw_, peak=5, flat=0.0, bad_percent=20.1)
    assert len(annots) == 2
    assert len(bads) == 0
    assert all(annot["description"] in spans for annot in annots)
    for annot in annots:
        start_idx, stop_idx = spans[annot["description"]]
        _check_annotation(raw_, annot, meas_date, first_samp, start_idx, stop_idx)

    # -- flat or drifting data on a channel already in info["bads"] --
    for peak, flat, signal in ((None, 0.0, 0.0), (5, None, ramp)):
        raw_ = raw.copy()
        raw_._data[-1, :] = signal
        annots, bads = annotate_amplitude(raw_, peak=peak, flat=flat, bad_percent=5)
        assert len(annots) == 0
        assert len(bads) == 0
@pytest.mark.parametrize("meas_date", (None, date))
@pytest.mark.parametrize("first_samp", (0, 10000))
def test_annotate_amplitude_with_overlap(meas_date, first_samp):
    """Test the annotations produced by overlapping bad segments.

    Covers a peak segment overlapping a flat one on another channel, two
    overlapping peak segments on the same channel, and two overlapping flat
    segments on different channels.
    """
    n_ch, n_times = 11, 1000
    rng = np.random.RandomState(0)
    data = rng.randn(n_ch, n_times)
    assert not (np.diff(data, axis=-1) == 0).any()  # nothing flat at first
    info = create_info(n_ch, 1000.0, "eeg")
    # a non-zero first_samp is also exercised (from annotate_flat, gh-6295)
    raw = RawArray(data, info, first_samp=first_samp)
    raw.info["bads"] = [raw.ch_names[-1]]
    raw.set_meas_date(meas_date)

    # -- peak on channel 1 overlapping flat on channel 0 --
    corrupted = raw.copy()
    corrupted._data[0, 800:] = 0.0
    corrupted._data[1, 700:900] = np.arange(0, 200 * 10, 10)
    corrupted._data[1, 900:] += corrupted._data[1, 899]  # offset following samples
    annots, bads = annotate_amplitude(corrupted, peak=5, flat=0, bad_percent=25)
    assert len(annots) == 2
    assert len(bads) == 0
    spans = {"BAD_peak": (700, 899), "BAD_flat": (800, -1)}
    assert all(annot["description"] in spans for annot in annots)
    for annot in annots:
        start_idx, stop_idx = spans[annot["description"]]
        _check_annotation(
            corrupted, annot, meas_date, first_samp, start_idx, stop_idx
        )

    # -- two overlapping peak segments on the same channel --
    corrupted = raw.copy()
    corrupted._data[0, 700:900] = np.arange(0, 200 * 10, 10)
    corrupted._data[0, 800:] = np.arange(1000, 300 * 10, 10)
    annots, bads = annotate_amplitude(corrupted, peak=5, flat=None, bad_percent=50)
    assert len(annots) == 1
    assert len(bads) == 0
    assert annots[0]["description"] == "BAD_peak"
    _check_annotation(corrupted, annots[0], meas_date, first_samp, 700, -1)

    # -- overlapping flat segments on two different channels --
    corrupted = raw.copy()
    corrupted._data[0, 700:900] = 0.0
    corrupted._data[1, 800:] = 0.0
    annots, bads = annotate_amplitude(
        corrupted, peak=None, flat=0.01, bad_percent=50
    )
    assert len(annots) == 1
    assert len(bads) == 0
    assert annots[0]["description"] == "BAD_flat"
    _check_annotation(corrupted, annots[0], meas_date, first_samp, 700, -1)
@pytest.mark.parametrize("meas_date", (None, date))
@pytest.mark.parametrize("first_samp", (0, 10000))
def test_annotate_amplitude_multiple_ch_types(meas_date, first_samp):
    """Test annotation with several channel types and type-based ``picks``."""
    n_ch, n_times = 11, 1000
    data = np.random.RandomState(0).randn(n_ch, n_times)
    assert not (np.diff(data, axis=-1) == 0).any()  # nothing flat at first
    # channels 0-2: eeg, 3-4: mag, 5-8: grad, 9-10: eeg
    ch_types = ["eeg"] * 3 + ["mag"] * 2 + ["grad"] * 4 + ["eeg"] * 2
    info = create_info(n_ch, 1000.0, ch_types)
    # a non-zero first_samp is also exercised (from annotate_flat, gh-6295)
    raw = RawArray(data, info, first_samp=first_samp)
    raw.info["bads"] = [raw.ch_names[-1]]
    raw.set_meas_date(meas_date)

    def corrupted_copy():
        """Copy raw, flatten eeg channel 1 and add a drift to grad channel 5."""
        raw_ = raw.copy()
        raw_._data[1, 800:] = 0.0
        raw_._data[5, :200] = np.arange(0, 200 * 10, 10)
        raw_._data[5, 200:] += raw_._data[5, 199]  # offset the following samples
        return raw_

    # -- both channel types picked: one annotation per channel type --
    raw_ = corrupted_copy()
    annots, bads = annotate_amplitude(raw_, peak=5, flat=0, bad_percent=50)
    assert len(annots) == 2
    assert len(bads) == 0
    spans = {"BAD_peak": (0, 199), "BAD_flat": (800, -1)}
    assert all(annot["description"] in spans for annot in annots)
    for annot in annots:
        start_idx, stop_idx = spans[annot["description"]]
        _check_annotation(raw_, annot, meas_date, first_samp, start_idx, stop_idx)

    # -- a single channel type picked: only its segment is annotated --
    for picks, description, start_idx, stop_idx in (
        ("eeg", "BAD_flat", 800, -1),
        ("grad", "BAD_peak", 0, 199),
    ):
        raw_ = corrupted_copy()
        annots, bads = annotate_amplitude(
            raw_, peak=5, flat=0, bad_percent=50, picks=picks
        )
        assert len(annots) == 1
        assert len(bads) == 0
        _check_annotation(raw_, annots[0], meas_date, first_samp, start_idx, stop_idx)
        assert annots[0]["description"] == description
@testing.requires_testing_data
def test_flat_bad_acq_skip():
    """Test flat detection in the presence of ``bad_acq_skip`` annotations."""
    # -- recording with a couple of acquisition skips and flat channels --
    raw = read_raw_fif(skip_fname, preload=True)
    annots, bads = annotate_amplitude(raw, flat=0)
    assert len(annots) == 0
    # MaxFilter finds the same 21 channels
    channel_numbers = (
        "141 331 421 431 611 641 1011 1021 1031 1241 1421 "
        "1741 1841 2011 2131 2141 2241 2531 2541 2611 2621"
    ).split()
    assert bads == [f"MEG{number.zfill(4)}" for number in channel_numbers]

    # -- synthetic data with a bad_acq_skip (onset 0.5, duration 0.2) --
    n_ch, n_times = 11, 1000
    data = np.random.RandomState(0).randn(n_ch, n_times)
    assert not (np.diff(data, axis=-1) == 0).any()  # nothing flat at first
    info = create_info(n_ch, 1000.0, "eeg")
    raw = RawArray(data, info, first_samp=0)
    raw.info["bads"] = [raw.ch_names[-1]]
    raw.set_annotations(
        Annotations([0.5], [0.2], ["bad_acq_skip"], orig_time=None)
    )

    # flat segment overlapping the left, then the right, edge of bad_acq_skip:
    # only the part outside the skip is annotated
    for flat_start, flat_stop, start_idx, stop_idx in (
        (400, 600, 400, 499),
        (600, 800, 700, 799),
    ):
        raw_ = raw.copy()
        raw_._data[0, flat_start:flat_stop] = 0.0
        annots, bads = annotate_amplitude(raw_, peak=None, flat=0, bad_percent=25)
        assert len(annots) == 1
        assert len(bads) == 0
        assert annots[0]["description"] == "BAD_flat"
        _check_annotation(raw_, annots[0], None, 0, start_idx, stop_idx)

    # flat segment spanning the whole bad_acq_skip: one annotation on each side
    raw_ = raw.copy()
    raw_._data[0, 200:800] = 0.0
    annots, bads = annotate_amplitude(raw_, peak=None, flat=0, bad_percent=41)
    assert len(annots) == 2
    assert len(bads) == 0
    by_onset = sorted(annots, key=lambda annot: annot["onset"])
    assert all(annot["description"] == "BAD_flat" for annot in by_onset)
    _check_annotation(raw_, by_onset[0], None, 0, 200, 500)
    _check_annotation(raw_, by_onset[1], None, 0, 700, 799)
def _check_annotation(raw, annot, meas_date, first_samp, start_idx, stop_idx):
    """Assert that an annotation spans the samples ``start_idx``..``stop_idx``.

    The annotation's ``orig_time`` must equal ``meas_date``. When ``meas_date``
    is not None, ``first_samp / raw.info["sfreq"]`` is subtracted from the
    annotation times before comparing them with ``raw.times``. Onset and end
    are compared with an absolute tolerance of 1e-4.
    """
    assert meas_date == annot["orig_time"]
    # time shift between annotation times and raw.times (only with meas_date)
    shift = 0.0 if meas_date is None else first_samp / raw.info["sfreq"]
    assert np.isclose(raw.times[start_idx], annot["onset"] - shift, atol=1e-4)
    end = annot["onset"] + annot["duration"] - shift
    assert np.isclose(raw.times[stop_idx], end, atol=1e-4)
def test_invalid_arguments():
    """Test error messages raised by invalid arguments.

    ``pytest.raises(match=...)`` treats its argument as a regular expression,
    so every expected message is wrapped in ``re.escape`` to match it
    literally (the messages contain ``.``, ``(`` and ``)``).
    """
    n_ch, n_times = 2, 100
    data = np.random.RandomState(0).randn(n_ch, n_times)
    info = create_info(n_ch, 100.0, "eeg")
    raw = RawArray(data, info, first_samp=0)
    # (keyword arguments, expected ValueError message) pairs
    invalid_cases = (
        # negative scalar PTP thresholds
        (
            dict(peak=None, flat=-1),
            "Argument 'flat' should define a positive threshold. Provided: '-1'.",
        ),
        (
            dict(peak=-1, flat=None),
            "Argument 'peak' should define a positive threshold. Provided: '-1'.",
        ),
        # negative PTP threshold for one channel type
        (
            dict(peak=None, flat=dict(eeg=1, eog=-1)),
            "Argument 'flat' should define positive thresholds. "
            "Provided for channel type 'eog': '-1'.",
        ),
        (
            dict(peak=dict(eeg=1, eog=-1), flat=None),
            "Argument 'peak' should define positive thresholds. "
            "Provided for channel type 'eog': '-1'.",
        ),
        # both PTP thresholds set to None
        (
            dict(peak=None, flat=None),
            "At least one of the arguments 'peak' or 'flat' must not be None.",
        ),
        # bad_percent outside [0, 100]
        (
            dict(peak=dict(eeg=1), flat=None, bad_percent=-1),
            "Argument 'bad_percent' should define a percentage between 0% and "
            "100%. Provided: -1.0%.",
        ),
        # negative min_duration
        (
            dict(peak=dict(eeg=1), flat=None, min_duration=-1),
            "Argument 'min_duration' should define a positive duration in "
            "seconds. Provided: '-1.0' seconds.",
        ),
        # min_duration equal to the raw duration
        (
            dict(peak=dict(eeg=1), flat=None, min_duration=1.0),
            "Argument 'min_duration' should define a positive duration in "
            "seconds shorter than the raw duration (1.0 seconds). Provided: "
            "'1.0' seconds.",
        ),
        # min_duration longer than the raw duration
        (
            dict(peak=dict(eeg=1), flat=None, min_duration=10),
            "Argument 'min_duration' should define a positive duration in "
            "seconds shorter than the raw duration (1.0 seconds). Provided: "
            "'10.0' seconds.",
        ),
    )
    for kwargs, message in invalid_cases:
        with pytest.raises(ValueError, match=re.escape(message)):
            annotate_amplitude(raw, **kwargs)