[074d3d]: / mne / preprocessing / tests / test_pca_obs.py

Download this file

108 lines (90 with data), 3.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import re
from pathlib import Path

import numpy as np
import pytest

from mne.io import read_raw_fif
from mne.io.fiff.raw import Raw
from mne.preprocessing import apply_pca_obs
# Shared I/O test data lives in mne/io/tests/data; parents[2] of this
# file (mne/preprocessing/tests/test_pca_obs.py) is the ``mne`` package dir.
data_path = Path(__file__).parents[2].joinpath("io", "tests", "data")
# Raw FIF recording used by every test in this module.
raw_fname = data_path.joinpath("test_raw.fif")
@pytest.fixture()
def short_raw_data():
    """Load the ``test_raw.fif`` recording with its data preloaded.

    The full recording is returned as-is (it is not cropped and no channels
    are picked). The fixture is function-scoped (pytest's default), so each
    test receives a freshly loaded instance and any in-place modification made
    by one test cannot leak into another.

    Returns
    -------
    raw : Raw
        The preloaded raw instance read from ``raw_fname``.
    """
    return read_raw_fif(raw_fname, preload=True)
def test_heart_artifact_removal(short_raw_data: Raw):
    """Test PCA-OBS analysis and heart artifact removal of ECG datasets.

    Only the EEG channels passed via ``picks`` may change; every other column
    of the exported data frame (including ``time``) must remain untouched.
    """
    pd = pytest.importorskip("pandas")

    # Snapshot the original data before processing: the heart artifact may be
    # removed in place, so keep an independent deep copy for comparison.
    orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True)

    # Fake 18 evenly spaced QRS event times (in the same units as the "time"
    # column) spanning the recording. The first and last points of the
    # linspace are dropped so no event sits exactly on the recording edges.
    ecg_event_times = np.linspace(0, orig_df["time"].iloc[-1], 20)[1:-1]

    # perform heart artifact removal
    short_raw_data = apply_pca_obs(
        raw=short_raw_data, picks=["eeg"], qrs_times=ecg_event_times, n_jobs=1
    )

    # compare processed df to original df
    removed_heart_artifact_df: pd.DataFrame = short_raw_data.to_data_frame()

    # ensure all column names remain the same
    pd.testing.assert_index_equal(
        orig_df.columns,
        removed_heart_artifact_df.columns,
    )

    # ensure every column starting with EEG has been altered; guard against
    # the loop below passing vacuously when no EEG columns are present
    altered_cols = [c for c in orig_df.columns if c.startswith("EEG")]
    assert altered_cols, "expected EEG channels in the test recording"
    for col in altered_cols:
        # assert_series_equal must fail, i.e. the column differs by more than
        # that function's default comparison tolerance
        with pytest.raises(AssertionError):
            pd.testing.assert_series_equal(
                orig_df[col],
                removed_heart_artifact_df[col],
            )

    # ensure every column not starting with EEG has not been altered
    unaltered_cols = [c for c in orig_df.columns if not c.startswith("EEG")]
    pd.testing.assert_frame_equal(
        orig_df[unaltered_cols],
        removed_heart_artifact_df[unaltered_cols],
    )
# Test that various nonsensical ``qrs_times`` inputs raise the proper errors.
@pytest.mark.parametrize(
    ("picks", "qrs_times", "error", "exception"),
    [
        (
            ["eeg"],
            np.array([[0, 1], [2, 3]]),
            "qrs_times must be a 1d array",
            ValueError,
        ),
        (
            ["eeg"],
            [2, 3, 4],
            "qrs_times must be an instance of ndarray, got <class 'list'> instead.",
            TypeError,
        ),
        (
            ["eeg"],
            np.array([None, "foo", 2]),
            "qrs_times must be an array of either integers or floats",
            ValueError,
        ),
        (
            ["eeg"],
            np.array([-1, 0, 3]),
            "qrs_times must be strictly positive",
            ValueError,
        ),
    ],
    ids=["2d-array", "list", "non-numeric", "non-positive"],
)
def test_pca_obs_bad_input(
    short_raw_data: Raw,
    picks: list[str],
    qrs_times: np.ndarray,
    error: str,
    exception: type[Exception],
):
    """Test if bad input data raises the proper errors in the function sanity checks.

    ``error`` is a literal message fragment. ``pytest.raises(match=...)``
    treats its argument as a regular expression, so the fragment is escaped
    first; otherwise characters such as ``.`` would act as regex wildcards.
    """
    with pytest.raises(exception, match=re.escape(error)):
        apply_pca_obs(raw=short_raw_data, picks=picks, qrs_times=qrs_times)