[074d3d]: / mne / preprocessing / tests / test_stim.py

Download this file

155 lines (131 with data), 5.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from pathlib import Path

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_almost_equal

from mne.epochs import Epochs
from mne.event import read_events
from mne.io import read_raw_fif
from mne.preprocessing.stim import fix_stim_artifact
# Shared I/O test fixtures: three directories up from this file, then io/tests/data.
data_path = Path(__file__).parents[2].joinpath("io", "tests", "data")
raw_fname = data_path.joinpath("test_raw.fif")  # raw recording
event_fname = data_path.joinpath("test-eve.fif")  # matching event file
def _assert_linear(signal):
    """Assert that the 1D ``signal`` is a straight line (constant slope).

    The first differences are demeaned and then divided by the peak absolute
    value of ``signal`` before being compared to zero. Without that scaling,
    the fixed 6-decimal absolute tolerance of ``assert_array_almost_equal``
    passes trivially for any channel whose values are far below 1e-6.
    """
    diffs = np.diff(signal)
    diffs = diffs - np.mean(diffs)
    peak = np.abs(signal).max()
    if peak > 0:
        diffs = diffs / peak
    assert_array_almost_equal(diffs, np.zeros(len(diffs)))


def test_fix_stim_artifact():
    """Test fix stim artifact."""
    events = read_events(event_fname)

    # Without preload=True, fix_stim_artifact is expected to raise.
    raw = read_raw_fif(raw_fname)
    with pytest.raises(RuntimeError):
        fix_stim_artifact(raw)
    raw = read_raw_fif(raw_fname, preload=True)

    # --- Epochs: fix a window before the stimulus ---
    tmin, tmax, event_id = -0.2, 0.5, 1
    picks = ("meg", "eeg", "eog")
    epochs = Epochs(
        raw, events, event_id, tmin, tmax, picks=picks, preload=True, reject=None
    )
    e_start = int(np.ceil(epochs.info["sfreq"] * epochs.tmin))
    tmin, tmax = -0.045, -0.015
    # Inspect only -35 ms .. -15 ms: a sub-span of the fixed window that stays
    # clear of the window start at -45 ms.
    tmin_samp = int(-0.035 * epochs.info["sfreq"]) - e_start
    tmax_samp = int(-0.015 * epochs.info["sfreq"]) - e_start

    epochs = fix_stim_artifact(
        epochs, tmin=tmin, tmax=tmax, mode="linear", picks=("eeg", "eog")
    )
    # Picked channels are linearly interpolated ...
    data = epochs.get_data(("eeg", "eog"))[:, :, tmin_samp:tmax_samp]
    _assert_linear(data[0][0])
    # ... while the unpicked MEG channels are not.
    data = epochs.get_data("meg")[:, :, tmin_samp:tmax_samp]
    diff_data0 = np.diff(data[0][0])
    diff_data0 -= np.mean(diff_data0)
    assert np.all(diff_data0 != 0)

    epochs = fix_stim_artifact(epochs, tmin=tmin, tmax=tmax, mode="window")
    data_from_epochs_fix = epochs.get_data(copy=False)[:, :, tmin_samp:tmax_samp]
    # NOTE: this only asserts that at least one inspected sample is exactly 0.
    assert np.any(data_from_epochs_fix == 0)

    baseline = (-0.1, -0.05)
    epochs = fix_stim_artifact(
        epochs, tmin=tmin, tmax=tmax, baseline=baseline, mode="constant"
    )
    b_start = int(np.ceil(epochs.info["sfreq"] * baseline[0]))
    b_end = int(np.ceil(epochs.info["sfreq"] * baseline[1]))
    base_t1 = b_start - e_start
    base_t2 = b_end - e_start
    epochs_data = epochs.get_data(copy=False)
    baseline_mean = epochs_data[:, :, base_t1:base_t2].mean(axis=2)[0][0]
    data = epochs_data[:, :, tmin_samp:tmax_samp]
    # Float means are compared with a tolerance rather than exact equality.
    assert_allclose(data[0][0][0], baseline_mean)

    # --- Raw: fix a window before the stimulus ---
    event_idx = np.where(events[:, 2] == 1)[0][0]
    tmin, tmax = -0.045, -0.015
    tmin_samp = int(-0.035 * raw.info["sfreq"])
    tmax_samp = int(-0.015 * raw.info["sfreq"])
    # Sample of the first event with id 1, offset by raw.first_samp.
    tidx = int(events[event_idx, 0] - raw.first_samp)
    check = slice(tidx + tmin_samp, tidx + tmax_samp)
    with pytest.raises(ValueError):
        fix_stim_artifact(raw, events=np.array([]))

    raw = fix_stim_artifact(
        raw,
        events=None,
        event_id=1,
        tmin=tmin,
        tmax=tmax,
        mode="linear",
        stim_channel="STI 014",
    )
    data, _ = raw[:, check]
    _assert_linear(data[0])

    raw = fix_stim_artifact(
        raw, events, event_id=1, tmin=tmin, tmax=tmax, mode="window"
    )
    data, _ = raw[:, check]
    # NOTE: this only asserts that at least one inspected sample is exactly 0.
    assert np.any(data == 0)

    raw = fix_stim_artifact(
        raw,
        events,
        event_id=1,
        tmin=tmin,
        tmax=tmax,
        baseline=baseline,
        mode="constant",
    )
    data, _ = raw[:, check]
    # b_start/b_end were computed from epochs.info["sfreq"] above and are
    # reused here as offsets from the event sample.
    baseline_data, _ = raw[:, (tidx + b_start) : (tidx + b_end)]
    assert_allclose(baseline_data.mean(axis=1)[0], data[0][0])

    # --- Epochs built from the already-fixed raw ---
    tmin, tmax, event_id = -0.2, 0.5, 1
    epochs = Epochs(
        raw,
        events,
        event_id,
        tmin,
        tmax,
        picks=picks,
        preload=True,
        reject=None,
        baseline=None,
    )
    e_start = int(np.ceil(epochs.info["sfreq"] * epochs.tmin))
    tmin_samp = int(-0.035 * epochs.info["sfreq"]) - e_start
    tmax_samp = int(-0.015 * epochs.info["sfreq"]) - e_start
    data_from_raw_fix = epochs.get_data(copy=False)[:, :, tmin_samp:tmax_samp]
    # NOTE: this only asserts that at least one inspected sample is exactly 0.
    assert np.any(data_from_raw_fix == 0)

    # --- Evoked: fix a window after the stimulus ---
    evoked = epochs.average()
    tmin, tmax = 0.005, 0.045
    # Inspect only 15 ms .. 35 ms, inside the 5 ms .. 45 ms fixed window.
    tmin_samp = int(0.015 * evoked.info["sfreq"]) - evoked.first
    tmax_samp = int(0.035 * evoked.info["sfreq"]) - evoked.first
    evoked = fix_stim_artifact(evoked, tmin=tmin, tmax=tmax, mode="linear")
    data = evoked.data[:, tmin_samp:tmax_samp]
    _assert_linear(data[0])

    evoked = fix_stim_artifact(evoked, tmin=tmin, tmax=tmax, mode="window")
    data = evoked.data[:, tmin_samp:tmax_samp]
    # NOTE: this only asserts that at least one inspected sample is exactly 0.
    assert np.any(data == 0)

    evoked = fix_stim_artifact(
        evoked, tmin=tmin, tmax=tmax, baseline=baseline, mode="constant"
    )
    base_t1 = int(baseline[0] * evoked.info["sfreq"]) - evoked.first
    base_t2 = int(baseline[1] * evoked.info["sfreq"]) - evoked.first
    data = evoked.data[:, tmin_samp:tmax_samp]
    baseline_mean = evoked.data[:, base_t1:base_t2].mean(axis=1)[0]
    assert_allclose(data[0][0], baseline_mean)