"""Some utility functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import logging
import os
import os.path as op
import tempfile
import time
from collections.abc import Iterable
from threading import Thread
import numpy as np
from ._logging import logger
from .check import _check_option
from .config import get_config
class ProgressBar:
    """Generate a command-line progressbar.

    Parameters
    ----------
    iterable : iterable | int | None
        The iterable to use. Can also be an int for backward compatibility
        (acts like ``max_value``). If None, ``max_value`` must be provided
        and is used as the total.
    initial_value : int
        Initial value of process, useful when resuming process from a specific
        value, defaults to 0.
    mesg : str
        Message to include at end of progress bar.
    max_total_width : int | str
        Maximum total message width. Can use "auto" (default) to try to set
        a sane value based on the current terminal width.
    max_value : int | None
        The max value. If None, the length of ``iterable`` will be used.
        Ignored when ``iterable`` is an int.
    which_tqdm : str | None
        Which tqdm module to use. Can be "tqdm", "tqdm.notebook", or "off".
        Defaults to ``None``, which uses the value of the MNE_TQDM environment
        variable, or ``"tqdm.auto"`` if that is not set.
    **kwargs : dict
        Additional keyword arguments for tqdm.
    """

    def __init__(
        self,
        iterable=None,
        initial_value=0,
        mesg=None,
        max_total_width="auto",
        max_value=None,
        *,
        which_tqdm=None,
        **kwargs,
    ):
        # The following mimics this, but with configurable module to use
        # from ..externals.tqdm import auto
        import tqdm

        if which_tqdm is None:
            which_tqdm = get_config("MNE_TQDM", "tqdm.auto")
        # Only the first five characters are validated, so any "tqdm.<sub>"
        # name passes here and is checked by the import below.
        _check_option(
            "MNE_TQDM", which_tqdm[:5], ("tqdm", "tqdm.", "off"), extra="beginning"
        )
        logger.debug(f"Using ProgressBar with {which_tqdm}")
        if which_tqdm not in ("tqdm", "off"):
            try:
                __import__(which_tqdm)
            except Exception as exc:
                raise ValueError(
                    f"Unknown tqdm backend {repr(which_tqdm)}, got: {exc}"
                ) from None
            tqdm = getattr(tqdm, which_tqdm.split(".", 1)[1])
        tqdm = tqdm.tqdm
        defaults = dict(
            leave=True,
            mininterval=0.016,
            miniters=1,
            smoothing=0.05,
            bar_format="{percentage:3.0f}%|{bar}| {desc} : {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt:>11}{postfix}]",  # noqa: E501
        )
        # User-supplied keyword arguments take precedence over our defaults.
        for key, val in defaults.items():
            kwargs.setdefault(key, val)
        self.iterable, self.max_value = self._resolve_total(iterable, max_value)
        if max_total_width == "auto":
            max_total_width = None  # tqdm's auto
        # Only the unique file *name* is kept; the file itself is (re)created
        # by ``__enter__`` when a memmap is actually needed.
        with tempfile.NamedTemporaryFile("wb", prefix="tmp_mne_prog") as tf:
            self._mmap_fname = tf.name
        del tf  # should remove the file
        self._mmap = None
        # Stay silent when logging is above INFO or when explicitly turned off.
        disable = logger.level > logging.INFO or which_tqdm == "off"
        self._tqdm = tqdm(
            iterable=self.iterable,
            desc=mesg,
            total=self.max_value,
            initial=initial_value,
            ncols=max_total_width,
            disable=disable,
            **kwargs,
        )

    @staticmethod
    def _resolve_total(iterable, max_value):
        """Work out the iterable and the total count for the bar.

        Parameters
        ----------
        iterable : iterable | int | None
            The object passed as ``iterable`` to the constructor.
        max_value : int | None
            The object passed as ``max_value`` to the constructor.

        Returns
        -------
        iterable : iterable | None
            The iterable to wrap, or None if only a count was given.
        max_value : int
            The total number of steps.

        Raises
        ------
        TypeError
            If both ``iterable`` and ``max_value`` are None, or if
            ``max_value`` is None and ``iterable`` has no length.
        """
        if isinstance(iterable, Iterable):
            total = len(iterable) if max_value is None else max_value
            return iterable, total
        if iterable is None:
            # ``int(None)`` used to fail here with an opaque message even
            # though None is a documented value; honor max_value instead.
            if max_value is None:
                raise TypeError(
                    "Either iterable or max_value must be provided, got None "
                    "for both."
                )
            return None, int(max_value)
        # Backward-compatible int form: max_value is ignored
        return None, int(iterable)

    def update(self, cur_value):
        """Update progressbar with current value of process.

        Parameters
        ----------
        cur_value : number
            Current value of process. Should be <= max_value (but this is not
            enforced). The percent of the progressbar will be computed as
            ``(cur_value / max_value) * 100``.
        """
        # tqdm only accepts increments, so convert the absolute value
        self.update_with_increment_value(cur_value - self._tqdm.n)

    def update_with_increment_value(self, increment_value):
        """Update progressbar with an increment.

        Parameters
        ----------
        increment_value : int
            Value of the increment of process. The percent of the progressbar
            will be computed as
            ``((cur_value + increment_value) / max_value) * 100``.
        """
        try:
            self._tqdm.update(increment_value)
        except TypeError:  # can happen during GC on Windows
            pass

    def __iter__(self):
        """Iterate to auto-increment the pbar with 1."""
        yield from self._tqdm

    def subset(self, idx):
        """Make a joblib-friendly index subset updater.

        Parameters
        ----------
        idx : ndarray
            List of indices for this subset.

        Returns
        -------
        updater : instance of PBSubsetUpdater
            Class with a ``.update(ii)`` method.
        """
        return _PBSubsetUpdater(self, idx)

    def __enter__(self):  # noqa: D105
        # This should only be used with pb.subset and parallelization
        if op.isfile(self._mmap_fname):
            os.remove(self._mmap_fname)
        # prevent corner cases where self.max_value == 0
        self._mmap = np.memmap(
            self._mmap_fname, bool, "w+", shape=max(self.max_value, 1)
        )
        self.update(0)  # must be zero as we just created the memmap
        # Background thread that keeps the bar in sync with the memmap
        self._thread = _UpdateThread(self)
        self._thread.start()
        return self

    def __exit__(self, type_, value, traceback):  # noqa: D105
        try:
            # One last sync from the main thread so the final count is shown
            self.update(self._mmap.sum())
            self._tqdm.close()
        finally:
            # Always stop the updater thread and drop the memmap, even if the
            # final update failed, so no thread or temporary file is leaked.
            self._thread._mne_run = False
            self._thread.join()
            self._mmap = None
            if op.isfile(self._mmap_fname):
                try:
                    os.remove(self._mmap_fname)
                # happens on Windows sometimes
                except PermissionError:  # pragma: no cover
                    pass

    def __del__(self):
        """Ensure output completes."""
        # ``_tqdm`` is missing if ``__init__`` raised before creating it
        if getattr(self, "_tqdm", None) is not None:
            self._tqdm.close()
class _UpdateThread(Thread):
def __init__(self, pb):
super().__init__(daemon=True)
self._mne_run = True
self._mne_pb = pb
def run(self):
while self._mne_run:
self._mne_pb.update(self._mne_pb._mmap.sum())
time.sleep(1.0 / 30.0) # 30 Hz refresh is plenty
class _PBSubsetUpdater:
def __init__(self, pb, idx):
self.mmap = pb._mmap
self.idx = idx
def update(self, ii):
self.mmap[self.idx[ii - 1]] = True