[7f9fb8]: / mne / preprocessing / stim.py

Download this file

177 lines (156 with data), 5.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from scipy.interpolate import interp1d
from scipy.signal.windows import hann
from .._fiff.pick import _picks_to_idx
from ..epochs import BaseEpochs
from ..event import find_events
from ..evoked import Evoked
from ..io import BaseRaw
from ..utils import _check_option, _check_preload, _validate_type, fill_doc
def _get_window(start, end):
    """Build the ``1 - hann`` attenuation window for an artifact segment.

    Parameters
    ----------
    start : int
        First sample of the segment.
    end : int
        Sample at which the segment stops.

    Returns
    -------
    window : ndarray, shape (abs(end - start),)
        Multiplicative window: the two outermost samples on each side come
        from ``1 - hann(4)`` and every sample in between is zero.
    """
    taper = hann(4)
    # The 4-point Hann taper supplies two samples per edge; the remainder of
    # the segment is a flat plateau of ones (i.e. zeros once inverted).
    n_plateau = np.abs(end - start) - 4
    profile = np.concatenate([taper[:2], np.ones(n_plateau), taper[-2:]])
    return 1 - profile
def _fix_artifact(
    data, window, picks, first_samp, last_samp, base_tmin, base_tmax, mode
):
    """Overwrite the artifact segment of ``data`` in place.

    Parameters
    ----------
    data : ndarray, shape (n_channels, n_times)
        Array to modify; columns ``first_samp:last_samp`` of the picked rows
        are replaced.
    window : ndarray | None
        Multiplicative window, only read when ``mode == "window"``. Its length
        must equal ``last_samp - first_samp``.
    picks : array-like of int
        Row indices of the channels to process.
    first_samp : int
        First column of the artifact segment.
    last_samp : int
        Column at which the artifact segment stops (exclusive for the
        replaced span, but read as an anchor value in ``"linear"`` mode).
    base_tmin : int
        First column of the baseline, only read when ``mode == "constant"``.
    base_tmax : int
        Column at which the baseline stops, only read when
        ``mode == "constant"``.
    mode : str
        ``"linear"``, ``"window"`` or ``"constant"``; any other value leaves
        ``data`` untouched.
    """
    segment = slice(first_samp, last_samp)
    if mode == "linear":
        # Draw a straight line per channel between the values found at the
        # two boundary columns, then sample it on the segment.
        anchors = np.array([first_samp, last_samp])
        anchor_values = data[:, [first_samp, last_samp]][picks]
        line = interp1d(anchors, anchor_values)
        data[picks, segment] = line(np.arange(first_samp, last_samp))
    elif mode == "window":
        # Scale every picked channel by the same window.
        tapered = data[picks, segment] * window[np.newaxis, :]
        data[picks, segment] = tapered
    elif mode == "constant":
        # Fill the segment with each channel's own baseline average.
        fill = data[picks, base_tmin:base_tmax].mean(axis=1, keepdims=True)
        data[picks, segment] = fill
@fill_doc
def fix_stim_artifact(
    inst,
    events=None,
    event_id=None,
    tmin=0.0,
    tmax=0.01,
    *,
    baseline=None,
    mode="linear",
    stim_channel=None,
    picks=None,
):
    """Eliminate stimulation's artifacts from instance.

    .. note:: This function operates in-place, consider passing
              ``inst.copy()`` if this is not desired.

    Parameters
    ----------
    inst : instance of Raw or Epochs or Evoked
        The data.
    events : array, shape (n_events, 3) | None
        The list of events. Only used when inst is Raw. If None, the events
        are looked up with ``find_events`` using ``stim_channel``.
    event_id : int | None
        The id of the events generating the stimulation artifacts.
        If None, all events are used. Only used when inst is Raw.
    tmin : float
        Start time of the interpolation window in seconds.
    tmax : float
        End time of the interpolation window in seconds.
    baseline : None | tuple, shape (2,)
        The baseline to use when ``mode='constant'``, in which case it
        must be non-None. It is given in seconds on the same time axis as
        ``tmin`` and ``tmax`` and must cover at least one sample.

        .. versionadded:: 1.8
    mode : 'linear' | 'window' | 'constant'
        Way to fill the artifacted time interval.

        ``"linear"``
            Does linear interpolation.
        ``"window"``
            Applies a ``(1 - hanning)`` window.
        ``"constant"``
            Uses baseline average. baseline parameter must be provided.

        .. versionchanged:: 1.8
           Added the ``"constant"`` mode.
    stim_channel : str | None
        Stim channel to use when the events have to be looked up, i.e. when
        inst is Raw and ``events`` is None.
    %(picks_all_data)s

    Returns
    -------
    inst : instance of Raw or Evoked or Epochs
        Instance with modified data.

    Raises
    ------
    TypeError
        If ``inst`` is not a Raw, Epochs or Evoked instance.
    ValueError
        If ``mode='window'`` and the interval spans fewer than 4 samples, if
        ``mode='constant'`` and the baseline contains no samples, or if inst
        is Raw and no events are available.
    RuntimeError
        If inst is Epochs and a rejection has already been applied.
    """
    _check_option("mode", mode, ["linear", "window", "constant"])
    # Fail fast with the intended error instead of tripping over a missing
    # ``info`` attribute further down.
    if not isinstance(inst, (BaseRaw, BaseEpochs, Evoked)):
        raise TypeError(f"Not a Raw or Epochs or Evoked (got {type(inst)}).")

    sfreq = inst.info["sfreq"]
    s_start = int(np.ceil(sfreq * tmin))
    s_end = int(np.ceil(sfreq * tmax))
    if mode == "constant":
        _validate_type(
            baseline, (tuple, list), "baseline", extra="when mode='constant'"
        )
        _check_option("len(baseline)", len(baseline), [2])
        for bi, b in enumerate(baseline):
            _validate_type(
                b, "numeric", f"baseline[{bi}]", extra="when mode='constant'"
            )
        b_start = int(np.ceil(sfreq * baseline[0]))
        b_end = int(np.ceil(sfreq * baseline[1]))
        # An empty baseline would average to NaN and silently overwrite the
        # artifact window with NaN, so refuse it explicitly.
        if b_end <= b_start:
            raise ValueError(
                f"baseline {tuple(baseline)} contains no samples at "
                f"sfreq={sfreq} Hz; baseline[0] must be smaller than "
                "baseline[1] and the interval must span at least one sample."
            )
    else:
        # Placeholders only: the baseline bounds are never read in the other
        # modes.
        b_start = b_end = np.nan
    if (mode == "window") and (s_end - s_start) < 4:
        raise ValueError(
            'Time range is too short. Use a larger interval or set mode to "linear".'
        )
    window = None
    if mode == "window":
        window = _get_window(s_start, s_end)

    picks = _picks_to_idx(inst.info, picks, "data", exclude=())

    _check_preload(inst, "fix_stim_artifact")
    if isinstance(inst, BaseRaw):
        if events is None:
            events = find_events(inst, stim_channel=stim_channel)
        if len(events) == 0:
            raise ValueError("No events are found")
        if event_id is None:
            events_sel = np.arange(len(events))
        else:
            events_sel = events[:, 2] == event_id
        event_start = events[events_sel, 0]
        data = inst._data
        for event_idx in event_start:
            # Event samples are offset by ``inst.first_samp``; convert them to
            # column indices of ``data`` before applying the window bounds.
            offset = int(event_idx) - inst.first_samp
            first_samp = offset + s_start
            last_samp = offset + s_end
            base_t1 = offset + b_start
            base_t2 = offset + b_end
            _fix_artifact(
                data, window, picks, first_samp, last_samp, base_t1, base_t2, mode
            )
    elif isinstance(inst, BaseEpochs):
        if inst.reject is not None:
            raise RuntimeError(
                "Reject is already applied. Use reject=None in the constructor."
            )
        # Sample index of the epoch start, used to turn the window bounds
        # into column indices of each epoch.
        e_start = int(np.ceil(sfreq * inst.tmin))
        first_samp = s_start - e_start
        last_samp = s_end - e_start
        data = inst._data
        base_t1 = b_start - e_start
        base_t2 = b_end - e_start
        # The same columns are repaired in every epoch.
        for epoch in data:
            _fix_artifact(
                epoch, window, picks, first_samp, last_samp, base_t1, base_t2, mode
            )
    else:  # Evoked, guaranteed by the type check above
        first_samp = s_start - inst.first
        last_samp = s_end - inst.first
        data = inst.data
        base_t1 = b_start - inst.first
        base_t2 = b_end - inst.first
        _fix_artifact(
            data, window, picks, first_samp, last_samp, base_t1, base_t2, mode
        )

    return inst