[074d3d]: / mne / time_frequency / tests / test_spectrum.py

Download this file

715 lines (648 with data), 27.8 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import re
from functools import partial
import numpy as np
import pytest
from matplotlib.colors import same_color
from numpy.testing import assert_allclose, assert_array_equal
from mne import Annotations, BaseEpochs, create_info, make_fixed_length_epochs
from mne.io import RawArray
from mne.time_frequency import read_spectrum
from mne.time_frequency.multitaper import _psd_from_mt
from mne.time_frequency.spectrum import (
EpochsSpectrumArray,
SpectrumArray,
combine_spectrum,
)
from mne.utils import _record_warnings
def test_compute_psd_errors(raw):
    """Check that invalid ``.compute_psd()`` calls raise the expected errors."""
    # an fmax just above half the sampling frequency must be rejected
    too_high = raw.info["sfreq"] * 0.51
    with pytest.raises(ValueError, match="must not exceed ½ the sampling"):
        raw.compute_psd(fmax=too_high)

    # unknown keyword arguments: first a single one, then several
    with pytest.raises(TypeError, match="unexpected keyword argument foo for"):
        raw.compute_psd(foo=None)
    with pytest.raises(TypeError, match="keyword arguments foo, bar for"):
        raw.compute_psd(foo=None, bar=None)

    # multitaper combined with annotation-based rejection is not implemented
    bad_span = Annotations(onset=0.01, duration=0.01, description="bad_foo")
    raw.set_annotations(bad_span)
    with pytest.raises(NotImplementedError, match='Cannot use method="multitaper"'):
        raw.compute_psd(method="multitaper", reject_by_annotation=True)
@pytest.mark.parametrize("method", ("welch", "multitaper"))
@pytest.mark.parametrize(
    (
        "fmin, fmax, tmin, tmax, picks, proj, n_fft, n_overlap, n_per_seg, "
        "average, window, bandwidth, adaptive, low_bias, normalization, output"
    ),
    [
        # defaults
        [
            *(0, np.inf, None, None, None, False),  # fmin ... proj
            *(256, 0, None, "mean", "hamming"),  # welch-only: n_fft ... window
            *(None, False, True, "length"),  # multitaper-only: bandwidth ...
            "power",  # output
        ],
        # non-defaults (excluding output)
        [
            *(5, 50, 1, 6, "grad", True),
            *(128, 8, 32, "median", "triang"),
            *(10, True, False, "full"),
            "power",
        ],
        # complex (testing with non-defaults doesn't increase coverage)
        [
            *(0, np.inf, None, None, None, False),
            *(256, 0, None, "mean", "hamming"),
            *(None, False, True, "length"),
            "complex",
        ],
    ],
)
def test_spectrum_params(
    method,
    fmin,
    fmax,
    tmin,
    tmax,
    picks,
    proj,
    n_fft,
    n_overlap,
    n_per_seg,
    average,
    window,
    bandwidth,
    adaptive,
    low_bias,
    normalization,
    output,
    raw,
):
    """Smoke-test valid parameter combinations of ``.compute_psd()``.

    Only the keyword arguments that belong to the selected ``method`` are
    forwarded, together with the method-independent ones.
    """
    shared = dict(
        method=method,
        fmin=fmin,
        fmax=fmax,
        tmin=tmin,
        tmax=tmax,
        picks=picks,
        proj=proj,
        output=output,
    )
    welch_only = dict(
        n_fft=n_fft,
        n_overlap=n_overlap,
        n_per_seg=n_per_seg,
        average=average,
        window=window,
    )
    multitaper_only = dict(
        bandwidth=bandwidth,
        adaptive=adaptive,
        low_bias=low_bias,
        normalization=normalization,
    )
    method_specific = welch_only if method == "welch" else multitaper_only
    raw.compute_psd(**shared, **method_specific)
def test_n_welch_windows(raw):
    """Test computation of welch windows https://mne.discourse.group/t/5734."""
    # a non-integer sampling frequency, set directly on the (unlocked) info
    odd_sfreq = 999.412109375
    with raw.info._unlock():
        raw.info["sfreq"] = odd_sfreq
    # the call itself is the test: it only has to complete without raising
    welch_kwargs = dict(n_fft=999, n_per_seg=999, n_overlap=250, average=None)
    raw.compute_psd(method="welch", **welch_kwargs)
def _get_inst(inst, request, *, evoked=None, average_tfr=None):
    """Return the object named by ``inst``.

    ``"evoked"`` and ``"average_tfr"`` are answered from the keyword arguments of
    the same name; every other name is looked up via ``request.getfixturevalue``.
    """
    # ↓ XXX workaround:
    # ↓ parametrized fixtures are not accessible via request.getfixturevalue
    # ↓ https://github.com/pytest-dev/pytest/issues/4666#issuecomment-456593913
    passed_in = (("evoked", evoked), ("average_tfr", average_tfr))
    for name, value in passed_in:
        if inst == name:
            return value
    return request.getfixturevalue(inst)
@pytest.mark.parametrize("inst", ("raw", "epochs_full", "evoked"))
def test_spectrum_io(inst, tmp_path, request, evoked):
    """Round-trip a spectrum through ``save()`` and ``read_spectrum()``."""
    pytest.importorskip("h5io")
    fname = tmp_path / f"{inst}-spectrum.h5"
    instance = _get_inst(inst, request, evoked=evoked)
    from_epochs = isinstance(instance, BaseEpochs)
    if from_epochs:
        # fake HED-like tags (https://mne.discourse.group/t/10634)
        instance.events[-2:, -1] = 2
        instance.event_id = {"foo/bar": 1, "foo/qux": 2}
    orig = instance.compute_psd()
    if from_epochs:
        # select via the shared "foo" tag prefix before saving
        orig = orig["foo"]
    orig.save(fname)
    assert orig == read_spectrum(fname)
def test_spectrum_copy(raw_spectrum):
    """Check that ``copy()`` yields an equal but independent Spectrum."""
    duplicate = raw_spectrum.copy()
    assert raw_spectrum == duplicate
    assert raw_spectrum is not duplicate
    # clobbering the copy's frequencies must leave the original untouched
    duplicate._freqs = None
    assert raw_spectrum.freqs is not None
@pytest.mark.parametrize("weights", ["nave", "equal", [1, -1]])
def test_combine_spectrum(raw_spectrum, weights):
    """Check the result of `combine_spectrum()` for each weighting scheme."""
    first = raw_spectrum.copy()
    second = raw_spectrum.copy()
    # `scale` is the expected ratio combined / first; None means "expect all zeros"
    scale = None
    if weights == "nave":
        first._nave = 1
        second._nave = 2
        second._data *= 2
        scale = 5 / 3  # (1 * 1 + 2 * 2) / (1 + 2)
    elif weights == "equal":
        second._data *= 2
        scale = 1.5  # (1 + 2) / 2
    combined = combine_spectrum([first, second], weights=weights)
    if scale is None:
        # weights [1, -1] on two identical spectra cancel out
        assert_allclose(combined.data, 0)
    else:
        assert_allclose(combined.data, first.data * scale)
def test_combine_spectrum_error_catch(raw_spectrum):
    """Check that `combine_spectrum()` rejects invalid input."""
    # --- bad weights: unknown string, then wrong length
    bad_string_msg = 'Weights must be a list of float, or "nave" or "equal"'
    with pytest.raises(ValueError, match=bad_string_msg):
        combine_spectrum([raw_spectrum, raw_spectrum], weights="foo")
    bad_size_msg = "Weights must be the same size as all_spectrum"
    with pytest.raises(ValueError, match=bad_size_msg):
        combine_spectrum([raw_spectrum, raw_spectrum], weights=[1, 1, 1])

    # --- "nave" weighting requested although nave is not specified
    with pytest.raises(ValueError, match="The 'nave' attribute is not specified"):
        combine_spectrum([raw_spectrum, raw_spectrum], weights="nave")

    # --- inconsistent channels: drop the first channel from a copy
    fewer_channels = raw_spectrum.copy()
    fewer_channels.drop_channels(fewer_channels.ch_names[0])
    with pytest.raises(AssertionError, match=".* do not contain the same channels"):
        combine_spectrum([raw_spectrum, fewer_channels], weights="equal")

    # --- inconsistent frequencies: shift the frequencies of a copy by one
    shifted_freqs = raw_spectrum.copy()
    shifted_freqs._freqs = shifted_freqs._freqs + 1
    with pytest.raises(AssertionError, match=".* do not contain the same frequencies"):
        combine_spectrum([raw_spectrum, shifted_freqs], weights="equal")
def test_spectrum_reject_by_annot(raw):
    """Check the effect of annotations on ``.compute_psd()``.

    The ``raw_spectrum`` fixture cannot be used because ``reject_by_annotation``
    is an argument of ``.compute_psd()`` itself, which is what is tested here.
    """
    psd_kwargs = dict(n_per_seg=512)  # smaller than shortest good span, to avoid warning
    # 1. no annotations at all
    without_annot = raw.compute_psd(**psd_kwargs)
    # 2. annotations whose description does not start with "bad"
    raw.set_annotations(Annotations([1, 5], [3, 3], ["test", "test"]))
    benign_annot = raw.compute_psd(**psd_kwargs)
    # 3. the same spans relabelled as "bad_*": honoured, then explicitly ignored
    raw.annotations.description = np.array(["bad_test", "bad_test"])
    rejected_annot = raw.compute_psd(**psd_kwargs)
    ignored_annot = raw.compute_psd(**psd_kwargs, reject_by_annotation=False)
    # only the spectrum computed with rejection in effect should differ
    assert without_annot == benign_annot
    assert without_annot == ignored_annot
    assert without_annot != rejected_annot
def test_spectrum_bads_exclude(raw):
    """Check that bad channels are only dropped when ``exclude="bads"``."""
    raw.pick("mag")  # get rid of IAS channel
    kept = raw.compute_psd()
    dropped = raw.compute_psd(exclude="bads")
    bads = raw.info["bads"]
    # default: bads are carried over unchanged
    assert bads == kept.info["bads"]
    # exclude="bads": no bads remain, and exactly the bad channels are missing
    assert dropped.info["bads"] == []
    missing = set(raw.ch_names) - set(dropped.ch_names)
    assert missing == set(bads)
def test_spectrum_getitem_raw(raw_spectrum):
    """Compare ``Spectrum.__getitem__`` against ``get_data`` (Raw-derived)."""
    fmax = 7
    ch_slice = slice(1, 3)
    expected = raw_spectrum.get_data(ch_slice, fmax=fmax)
    # index of the first frequency that is >= fmax
    stop = np.searchsorted(raw_spectrum.freqs, fmax)
    actual = raw_spectrum[ch_slice, slice(None, stop)]
    assert_array_equal(expected, actual)
def test_spectrum_getitem_epochs(epochs_spectrum):
    """Compare ``Spectrum.__getitem__`` against ``get_data`` (Epochs-derived)."""
    # testing data has just one epoch, its event_id label is "1"
    everything = epochs_spectrum.get_data()
    by_label = epochs_spectrum["1"].get_data()
    assert_array_equal(everything, by_label)
@pytest.mark.parametrize("method", ("mean", partial(np.std, axis=0)))
def test_epochs_spectrum_average(epochs_spectrum, method):
    """Check shape and dimension names after ``EpochsSpectrum.average()``."""
    averaged = epochs_spectrum.average(method=method)
    # the leading (epoch) dimension is gone, the remaining ones are unchanged
    _, *remaining = epochs_spectrum.shape
    assert averaged.shape == tuple(remaining)
    assert averaged._dims == ("channel", "freq")  # no 'epoch'
@pytest.mark.parametrize("inst", ("raw_spectrum", "epochs_spectrum", "evoked"))
def test_spectrum_to_data_frame(inst, request, evoked):
    """Test the to_data_frame method for Spectrum.

    Covers the wide and long formats, the ``index`` argument (including an
    invalid value) and the ``picks`` argument.
    """
    pytest.importorskip("pandas")
    from pandas.testing import assert_frame_equal

    # setup
    is_already_psd = inst in ("raw_spectrum", "epochs_spectrum")
    is_epochs = inst == "epochs_spectrum"
    inst = _get_inst(inst, request, evoked=evoked)
    # non-epochs data is 2D; prepend a pseudo "1 epoch" so the unpacking below works
    extra_dim = () if is_epochs else (1,)
    extra_cols = ["freq", "condition", "epoch"] if is_epochs else ["freq"]
    # compute PSD
    spectrum = inst if is_already_psd else inst.compute_psd(exclude="bads")
    n_epo, n_chan, n_freq = extra_dim + spectrum.get_data().shape
    # test wide format
    df_wide = spectrum.to_data_frame()
    n_row, n_col = df_wide.shape
    assert n_row == n_freq
    assert n_col == n_chan + len(extra_cols)
    assert set(spectrum.ch_names + extra_cols) == set(df_wide.columns)
    # test long format
    df_long = spectrum.to_data_frame(long_format=True)
    n_row, n_col = df_long.shape
    assert n_row == n_epo * n_freq * n_chan
    base_cols = ["channel", "ch_type", "value"]
    assert n_col == len(base_cols + extra_cols)
    assert set(base_cols + extra_cols) == set(df_long.columns)
    # test index
    index = extra_cols[-2:]  # ['freq'] or ['condition', 'epoch']
    df = spectrum.to_data_frame(index=index)
    if is_epochs:
        index_tuple = (
            list(spectrum.event_id)[0],  # condition
            spectrum.selection[0],  # epoch number
        )
        subset = df.loc[index_tuple]
        assert subset.shape == (n_freq, n_chan + 1)  # + 1 is the freq column
    with pytest.raises(ValueError, match='"time" is not a valid option'):
        spectrum.to_data_frame(index="time")
    # test picks
    picks = [0, 1]
    # Pick on a copy: `.pick()` reduces the object it is called on, so picking on
    # `spectrum` directly would make the `picks=` call below operate on an
    # already-reduced object and the comparison would be trivially true.
    _pick_first = spectrum.copy().pick(picks).to_data_frame()
    _pick_last = spectrum.to_data_frame(picks=picks)
    assert_frame_equal(_pick_first, _pick_last)
def _agg_helper(df, weights, group_cols):
    """Collapse one group of complex multitaper rows into a single row.

    The grouping-column values are taken from the group's first row, the
    remaining columns are combined with ``_psd_from_mt`` using ``weights``, and
    the result is returned as a ``pandas.Series`` keyed by ``df.columns``.
    """
    from pandas import Series

    group_values = df[group_cols].iloc[0].values.tolist()
    # add a leading axis and transpose before handing over to `_psd_from_mt`
    tapered = df.drop(columns=group_cols).values[np.newaxis].T
    power = np.atleast_1d(np.squeeze(_psd_from_mt(tapered, weights))).tolist()
    return Series(dict(zip(df.columns, group_values + power)))
@pytest.mark.parametrize("long_format", (False, True))
@pytest.mark.parametrize(
    "method, output",
    [("welch", "complex"), ("welch", "power"), ("multitaper", "complex")],
)
def test_unaggregated_spectrum_to_data_frame(raw, long_format, method, output):
    """Test converting unaggregated spectra (multiple segments/tapers) to data frame.

    The unaggregated spectrum is exported, aggregated with pandas, and compared
    with the data frame exported from the already-aggregated spectrum.
    """
    pytest.importorskip("pandas")
    from pandas.testing import assert_frame_equal

    from mne.utils.dataframe import _inplace

    is_welch = method == "welch"
    # aggregated spectrum → dataframe
    orig_df = raw.compute_psd(method=method).to_data_frame(long_format=long_format)
    # unaggregated welch or complex multitaper →
    # aggregate w/ pandas (to make sure we did reshaping right)
    kwargs = dict(average=False) if is_welch else dict()
    spectrum = raw.compute_psd(method=method, output=output, **kwargs)
    df = spectrum.to_data_frame(long_format=long_format)
    grouping_cols = ["freq"]
    drop_cols = ["segment" if is_welch else "taper"]
    if long_format:
        grouping_cols.append("channel")
        drop_cols.append("ch_type")
        orig_df.drop(columns="ch_type", inplace=True)
    # only do a couple freq bins, otherwise test takes forever for multitaper
    subset = partial(np.isin, test_elements=spectrum.freqs[:2])
    df = df.loc[subset(df["freq"])]
    orig_df = orig_df.loc[subset(orig_df["freq"])]
    # sort orig_df, because at present we can't actually prevent pandas from
    # sorting at the agg step *sigh*
    _inplace(orig_df, "sort_values", by=grouping_cols, ignore_index=True)
    # aggregate
    df = df.drop(columns=drop_cols)
    gb = df.groupby(grouping_cols, as_index=False, observed=False)
    if is_welch and output == "complex":

        def _fun(x):
            return np.mean(np.real(x * np.conj(x)))  # use mean to aggregate

        agg_df = gb.agg(_fun)
    elif is_welch:
        agg_df = gb.mean()  # excludes missing values itself
    else:
        gb = gb[df.columns]  # XXX: try removing when minimum pandas >= 2.1 is required
        agg_df = gb.apply(_agg_helper, spectrum.weights, grouping_cols)
    # even with check_categorical=False, we know that the *data* matches;
    # what may differ is the order of the "levels" in the *metadata* for the
    # channel name column
    sort_kwargs = dict(by=grouping_cols, ignore_index=True, inplace=True)
    agg_df.sort_values(**sort_kwargs)
    orig_df.sort_values(**sort_kwargs)
    # One can have categorical dtype and the other plain object, so don't check that
    assert_frame_equal(agg_df, orig_df, check_categorical=False, check_dtype=False)
# not testing with Evoked because it already has projs applied
@pytest.mark.parametrize("inst", ("raw", "epochs"))
def test_spectrum_proj(inst, request):
    """Check that ``proj`` changes the data and nothing else (gh 11177)."""
    instance = request.getfixturevalue(inst)
    with_proj = instance.compute_psd(proj=True)
    without_proj = instance.compute_psd(proj=False)
    data_equal = np.array_equal(with_proj.get_data(), without_proj.get_data())
    assert not data_equal
    # make sure only the data (and the projs) were different: once those two are
    # equalized, the spectra must compare equal
    with_proj._data = without_proj._data
    with with_proj.info._unlock():
        with_proj.info["projs"] = without_proj.info["projs"]
    assert with_proj == without_proj
@pytest.mark.parametrize(
    "method, average", [("welch", False), ("welch", "mean"), ("multitaper", None)]
)
def test_spectrum_complex(method, average):
    """Test ``output='complex'``: dims, shape, dtype, phase and zero-data warning."""
    # simulate a single-channel cosine with known frequency and phase
    sfreq = 100
    freq = 3.0
    phase = np.pi / 4  # should be recoverable
    n_samples = 10 * sfreq
    signal = np.cos(2 * np.pi * freq * np.arange(n_samples) / sfreq + phase)
    raw = RawArray(signal[np.newaxis], create_info(1, sfreq, "eeg"))
    epochs = make_fixed_length_epochs(raw, duration=2.0, preload=True)
    assert len(epochs) == 5
    assert len(epochs.times) == 2 * sfreq

    # expected dimension names and shape for the requested method
    kwargs = dict(output="complex", method=method)
    if method == "welch":
        kwargs.update(n_fft=sfreq, average=average)
        # unaveraged welch keeps a trailing "segment" dimension (2 segments here)
        per_segment = not average
        want_dims = ("epoch", "channel", "freq")
        want_shape = (5, 1, sfreq // 2 + 1)
        if per_segment:
            want_dims += ("segment",)
            want_shape += (2,)
    else:
        assert method == "multitaper"
        assert not average
        want_dims = ("epoch", "channel", "taper", "freq")
        want_shape = (5, 1, 7, sfreq + 1)

    spectrum = epochs.compute_psd(**kwargs)
    nearest = np.argmin(np.abs(spectrum.freqs - freq))
    assert spectrum.freqs[nearest] == freq
    assert spectrum._dims == want_dims
    assert spectrum.shape == want_shape
    assert spectrum.get_data().dtype == np.complex128

    # reduce to one complex coefficient at the simulated frequency
    coef = spectrum.get_data(fmin=freq, fmax=freq).mean(0)
    if method == "multitaper":
        coef = coef[..., 0, :]  # first taper
    elif not average:
        coef = coef.mean(-1)  # over segments
    coef = coef.item()
    # Test phase matches what was simulated
    assert_allclose(np.angle(coef), phase, rtol=1e-4)

    # Now test that it warns appropriately
    epochs._data[0, 0, :] = 0  # actually zero for one epoch and ch
    with pytest.warns(UserWarning, match="Zero value.*channel 0"):
        epochs.compute_psd(**kwargs)
    # But not if we mark that channel as bad
    epochs.info["bads"] = epochs.ch_names[:1]
    epochs.compute_psd(**kwargs)
def test_spectrum_kwarg_triaging(raw):
    """Check how the legacy ``plot_psd()`` method handles ``axes`` vs ``ax``."""
    import matplotlib.pyplot as plt

    axes = plt.subplots(1, 2)[1]
    # `axes` is the new param name: technically only valid for Spectrum.plot()
    expected_warning = (
        r"legacy plot_psd\(\) method.*unexpected keyword.*'axes'.*Try rewriting"
    )
    with _record_warnings(), pytest.warns(RuntimeWarning, match=expected_warning):
        raw.plot_psd(axes=axes)
    # `ax` is the correct legacy param name
    raw.plot_psd(ax=axes)
def _check_spectrum_equivalent(spect1, spect2, tmp_path):
    """Assert that two spectra have identical data and identical frequencies.

    ``tmp_path`` is accepted for the callers' convenience but is not used.
    """
    accessors = (
        lambda spect: spect.get_data(),  # the spectral data
        lambda spect: spect.freqs,  # the frequency axis
    )
    for access in accessors:
        assert_array_equal(access(spect1), access(spect2))
def test_spectrum_array_errors():
    """Exercise the input validation of the (Epochs)SpectrumArray constructors."""
    n_epochs, n_chans, n_freqs = 10, 5, 50
    sfreq = 100
    freqs = np.arange(n_freqs)
    rng = np.random.default_rng(44)
    data = rng.random((n_epochs, n_chans, n_freqs))
    dim_names = ("epoch", "channel", "freq")
    info = create_info(n_chans, sfreq, "eeg")
    # every case below is expected to fail with a ValueError
    raises = partial(pytest.raises, ValueError)

    # test incorrect ndims
    # (SpectrumArray allows 2-3D data; EpochsSpectrumArray allows 3-4D data)
    ndim_msg = "Invalid value for the 'data.ndim' parameter"
    wrong_ndim_cases = (
        (SpectrumArray, data[0, 0, :]),
        (SpectrumArray, np.expand_dims(data, axis=3)),
        (EpochsSpectrumArray, data[0, :, :]),
        (EpochsSpectrumArray, np.expand_dims(data, axis=(3, 4))),
    )
    for klass, wrong_ndim_data in wrong_ndim_cases:
        with raises(match=ndim_msg):
            klass(wrong_ndim_data, info, freqs, dim_names=dim_names)

    # test incorrect epochs location
    with raises(match="'epoch' must be the first dimension"):
        EpochsSpectrumArray(data, info, freqs, dim_names=("channel", "epoch", "freq"))

    # test mismatching events shape (one event more than there are epochs)
    n_events = n_epochs + 1
    event_columns = (
        np.arange(n_events),
        np.zeros(n_events, dtype=int),
        np.ones(n_events, dtype=int),
    )
    events = np.vstack(event_columns).T
    with raises(match=r"first dimension.*dimension of `events`"):
        EpochsSpectrumArray(data, info, freqs, events, dim_names=dim_names)

    # test data-dimname mismatch
    with raises(match=r"Expected data to have.*dimensions, got.*"):
        EpochsSpectrumArray(data, info, freqs, dim_names=dim_names[:-1])

    # test unrecognised dim_names
    bad_names_msg = "Invalid value for the 'dim_names' parameter"
    with raises(match=bad_names_msg):  # SpectrumArray; epoch not allowed
        SpectrumArray(data[0, :, :], info, freqs, dim_names=("epoch", "channel"))
    with raises(match=bad_names_msg):  # EpochsSpectrumArray
        EpochsSpectrumArray(
            data, info, freqs, dim_names=("epoch", "channel", "notfreq")
        )

    # test missing dim_names
    incomplete_names = (
        ("epoch", "channel", "channel"),
        ("epoch", "freq", "freq"),
        ("epoch", "epoch", "epoch"),
    )
    for names in incomplete_names:
        with raises(match="Both 'channel' and 'freq' must be present"):
            EpochsSpectrumArray(data, info, freqs, dim_names=names)

    # test incorrect channel location (for SpectrumArray; must be 1st dim)
    with raises(match="'channel' must be the first dimension"):
        SpectrumArray(data[0, :, :], info, freqs, dim_names=("freq", "channel"))
    # test incorrect channel location (for EpochsSpectrumArray; must be 2nd dim)
    with raises(match="'channel' must be the second dimension"):
        EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "freq", "channel"))

    # test mismatching number of channels
    with raises(match=re.escape("number of good + bad data channels")):
        EpochsSpectrumArray(data[:, :-1, :], info, freqs, dim_names=dim_names)

    # test incorrect taper position
    with raises(match="'taper' must be the second to last dim"):
        EpochsSpectrumArray(
            np.expand_dims(data, axis=3), info, freqs, dim_names=dim_names + ("taper",)
        )

    # test incorrect weight size
    taper_names = ("epoch", "channel", "taper", "freq")
    for bad_weights in (None, np.ones((1, 2, 1))):
        with raises(match=r"Expected size of `weights` to be.*, got.*"):
            EpochsSpectrumArray(
                np.expand_dims(data, axis=2),
                info,
                freqs,
                dim_names=taper_names,
                weights=bad_weights,
            )

    # test incorrect segment position
    with raises(match="'segment' must be the last dim"):
        EpochsSpectrumArray(
            np.expand_dims(data, axis=2),
            info,
            freqs,
            dim_names=("epoch", "channel", "segment", "freq"),
        )

    # test mismatching number of frequencies
    with raises(match=r"number of frequencies.*number of elements"):
        EpochsSpectrumArray(data[:, :, :-1], info, freqs, dim_names=dim_names)
@pytest.mark.parametrize(
    "kind, method, output, average",
    [
        ("raw", "welch", "power", "mean"),  # test with precomputed spectrum
        ("epochs", "welch", "power", False),  # test with segments
        ("epochs", "multitaper", "complex", None),  # test with tapers
    ],  # additional variants don't improve coverage
)
def test_spectrum_array(kind, method, output, average, tmp_path, request):
    """Rebuild a computed spectrum via the (Epochs)SpectrumArray constructors."""
    is_welch = method == "welch"
    # dimension names that match the data of the requested spectrum
    leading = ("epoch", "channel") if kind == "epochs" else ("channel",)
    if is_welch:
        trailing = ("freq",) if average else ("freq", "segment")
    else:  # i.e. multitaper
        trailing = ("freq",) if output == "power" else ("taper", "freq")
    dim_names = leading + trailing

    # obtain the reference spectrum: precomputed fixture where one fits,
    # otherwise computed here from the raw/epochs fixture
    if is_welch and output == "power" and average:
        spectrum = request.getfixturevalue(f"{kind}_spectrum")
    else:
        source = request.getfixturevalue(kind)
        kwargs = dict(average=average) if is_welch else dict()
        spectrum = source.compute_psd(method=method, output=output, **kwargs)

    data, freqs = spectrum.get_data(return_freqs=True)
    Klass = EpochsSpectrumArray if kind != "raw" else SpectrumArray
    spect_arr = Klass(
        data=data,
        info=spectrum.info,
        freqs=freqs,
        dim_names=dim_names,
        weights=spectrum.weights,
    )
    _check_spectrum_equivalent(spectrum, spect_arr, tmp_path)
@pytest.mark.parametrize(
    "method, output, average",
    [
        ("welch", "power", "mean"),  # test with precomputed spectrum
        ("welch", "complex", False),  # test aggr over segments & conversion to power
        ("multitaper", "complex", None),  # test aggr over tapers & conversion to power
    ],  # additional variants don't improve coverage
)
def test_plot_spectrum(method, output, average, request):
    """Smoke-test the plotting methods of EpochsSpectrum(Array).

    Only Epochs-derived spectra are used; testing Spectrum(Array) with raw data
    doesn't improve coverage.
    """
    is_welch = method == "welch"
    if is_welch and output == "power" and average:
        spectrum = request.getfixturevalue("epochs_spectrum")
    else:
        epochs = request.getfixturevalue("epochs")
        kwargs = dict(average=average) if is_welch else dict()
        spectrum = epochs.compute_psd(method=method, output=output, **kwargs)
    spectrum.info["bads"] = spectrum.ch_names[:1]  # one grad channel

    # averaged across channels
    for amp, sc in ((True, True), (False, False)):
        spectrum.plot(average=True, amplitude=amp, spatial_colors=sc)

    # one line per channel; the bad channel must be drawn in the "bad" color
    ch_types = spectrum.get_channel_types()
    n_grad = sum(ch_type == "grad" for ch_type in ch_types)
    for amp, sc in ((True, True), (False, False)):
        fig = spectrum.plot(average=False, amplitude=amp, spatial_colors=sc, exclude=())
        lines = fig.axes[0].lines[2:]  # grads, ignore two vlines
        assert len(lines) == n_grad
        bad_color = "0.5" if sc else "r"
        bad_flags = [same_color(line.get_color(), bad_color) for line in lines]
        assert sum(bad_flags) == 1

    spectrum.plot_topo()
    spectrum.plot_topomap()
def test_plot_spectrum_array_with_bads():
    """Test plotting a spectrum array with bads.

    ``get_data()`` with its default arguments returns fewer channels than
    ``info`` describes once a channel is marked bad, so building a SpectrumArray
    from it must fail; with ``exclude=()`` all channels are returned and both
    construction and plotting must succeed.
    """
    # seeded generator so that the test data are reproducible between runs
    rng = np.random.default_rng(0)
    raw = RawArray(rng.standard_normal((3, 1000)), create_info(3, 1000, "eeg"))
    raw.info["bads"] = [raw.ch_names[1]]
    spectrum = raw.compute_psd()
    with pytest.raises(
        ValueError, match=re.escape("number of good + bad data channels")
    ):
        SpectrumArray(spectrum.get_data(), spectrum.info, spectrum.freqs)
    spectrum2 = SpectrumArray(
        spectrum.get_data(exclude=()), spectrum.info, spectrum.freqs
    )
    spectrum2.plot(spatial_colors=False)
@pytest.mark.parametrize("dB", (False, True))
@pytest.mark.parametrize("amplitude", (False, True))
def test_plot_spectrum_dB(raw_spectrum, dB, amplitude):
    """Test that we properly handle amplitude/power and dB.

    A one-channel spectrum that is practically zero everywhere except for a
    single frequency bin is plotted, and the plotted value at that bin is
    compared with the value expected for the given ``dB`` / ``amplitude``
    combination.
    """
    idx = 7
    power = 3
    freqs = np.linspace(1, 100, 100)
    # fill with a tiny positive value instead of exact zeros
    # (presumably so that the dB conversion never sees log10(0))
    data = np.full((1, freqs.size), np.finfo(float).tiny)
    data[0, idx] = power
    info = create_info(ch_names=["delta"], sfreq=1000, ch_types="eeg")
    psd = SpectrumArray(data=data, info=info, freqs=freqs)
    with pytest.warns(RuntimeWarning, match="Channel locations not available"):
        fig = psd.plot(dB=dB, amplitude=amplitude)
    # the spectrum trace is the first line with one x-value per frequency
    trace = next(
        line for line in fig.axes[0].lines if len(line.get_data()[0]) == len(freqs)
    )
    got = trace.get_data()[1][idx]
    want = power * 1e12  # scaling for EEG (V → μV), squared
    if amplitude:
        want = np.sqrt(want)
    if dB:
        want = (20 if amplitude else 10) * np.log10(want)
    # compare with a tight tolerance rather than exact float equality: the plotted
    # value may be computed with a different order of floating-point operations
    assert_allclose(got, want, rtol=1e-10, err_msg=f"expected {want}, got {got}")