[074d3d]: / mne / utils / tests / test_progressbar.py

Download this file

131 lines (109 with data), 4.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from pathlib import Path
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from mne.parallel import parallel_func
from mne.utils import ProgressBar, array_split_idx, catch_logging, use_log_level
def test_progressbar(monkeypatch):
    """Check ProgressBar construction, iteration errors and MNE_TQDM backends."""
    # An array input is kept as the bar's iterable; its length is max_value.
    values = np.arange(10)
    bar = ProgressBar(values)
    assert bar.iterable is values
    assert bar.max_value == 10

    # An integer input only sets max_value and leaves nothing to iterate.
    bar = ProgressBar(10)
    assert bar.max_value == 10
    assert bar.iterable is None

    def _exhaust(obj):
        # Drive ``obj`` through a plain for-loop, discarding every item.
        for _ in obj:
            pass

    # Looping over an integer-initialized bar must raise.
    with pytest.raises(TypeError, match="not iterable"):
        _exhaust(bar)

    def _run_and_collect_log():
        # Iterate a short bar while capturing debug logging; return the text.
        with catch_logging("debug") as log_buf, ProgressBar(np.arange(3)) as pb:
            _exhaust(pb)
        return log_buf.getvalue()

    # The selected "tqdm" backend is reported in the debug log.
    monkeypatch.setenv("MNE_TQDM", "tqdm")
    assert "Using ProgressBar with tqdm\n" in _run_and_collect_log()

    # Malformed backend settings are rejected when the bar is constructed.
    invalid_settings = (
        ("broken", "Invalid value for the"),
        ("tqdm.broken", "Unknown tqdm"),
    )
    for setting, message in invalid_settings:
        monkeypatch.setenv("MNE_TQDM", setting)
        with pytest.raises(ValueError, match=message):
            ProgressBar(np.arange(3))

    # With "off", the backend line is the only debug output produced.
    monkeypatch.setenv("MNE_TQDM", "off")
    assert _run_and_collect_log() == "Using ProgressBar with off\n"
def _identity(x):
return x
def test_progressbar_parallel_basic(capsys):
    """Run ``parallel_func`` with a progress total and check the bar finishes."""
    # Start from an empty stdout capture.
    assert capsys.readouterr().out == ""
    n_items = 10
    parallel, run_one, _ = parallel_func(
        _identity, total=n_items, n_jobs=1, verbose=True
    )
    with use_log_level(True):
        results = parallel(run_one(item) for item in range(n_items))
    assert results == list(range(n_items))
    # The completed progress line is looked for in captured stderr.
    stderr_text = capsys.readouterr().err
    assert "100%" in stderr_text
def _identity_block(x, pb):
for ii in range(len(x)):
pb.update(ii + 1)
return x
def test_progressbar_parallel_advanced(capsys):
    """Test ProgressBar with parallel computing, advanced version.

    Each worker gets a ``pb.subset(...)`` and updates it. While the bar is
    open its memmap file must exist with every slot set; after the context
    exits the file must be gone.
    """
    assert capsys.readouterr().out == ""
    # This must be "1" because "capsys" won't get stdout properly otherwise
    parallel, p_fun, _ = parallel_func(_identity_block, n_jobs=1, verbose=False)
    arr = np.arange(10)
    with use_log_level(True):
        with ProgressBar(len(arr)) as pb:
            results = parallel(
                p_fun(x, pb.subset(pb_idx)) for pb_idx, x in array_split_idx(arr, 2)
            )
            assert Path(pb._mmap_fname).is_file()
            # Derive the memmap shape from ``arr`` instead of hard-coding 10.
            sum_ = np.memmap(
                pb._mmap_fname, dtype="bool", mode="r", shape=len(arr)
            ).sum()
            assert sum_ == len(arr)
    assert not Path(pb._mmap_fname).is_file(), "__exit__ not called?"
    # The chunks returned by the workers must reassemble to the input.
    assert_array_equal(np.concatenate(results), arr)
    assert "100%" in capsys.readouterr().err
def _identity_block_wide(x, pb):
for ii in range(len(x)):
for jj in range(2):
pb.update(ii * 2 + jj + 1)
return x, pb.idx
def test_progressbar_parallel_more(capsys):
    """Test ProgressBar with parallel computing, several updates per item.

    Splits ``arr`` with ``array_split_idx(..., n_per_split=2)`` so the bar has
    two progress slots per element. Checks the slot indices, the returned
    data, the memmap contents while open, and the memmap's removal on exit.
    """
    assert capsys.readouterr().out == ""
    # This must be "1" because "capsys" won't get stdout properly otherwise
    parallel, p_fun, _ = parallel_func(_identity_block_wide, n_jobs=1, verbose=False)
    arr = np.arange(10)
    n_slots = len(arr) * 2
    with use_log_level(True):
        with ProgressBar(n_slots) as pb:
            results = parallel(
                p_fun(x, pb.subset(pb_idx))
                for pb_idx, x in array_split_idx(arr, 2, n_per_split=2)
            )
            # Together, the subsets cover every progress slot, in order.
            idxs = np.concatenate([idx for _, idx in results])
            assert_array_equal(idxs, np.arange(n_slots))
            # Previously computed but never checked: the data must round-trip.
            data_out = np.concatenate([chunk for chunk, _ in results])
            assert_array_equal(data_out, arr)
            assert Path(pb._mmap_fname).is_file()
            sum_ = np.memmap(
                pb._mmap_fname, dtype="bool", mode="r", shape=n_slots
            ).sum()
            assert sum_ == n_slots
    assert not Path(pb._mmap_fname).is_file(), "__exit__ not called?"
    assert "100%" in capsys.readouterr().err