# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import re
import shutil
import zipfile
from functools import partial
from pathlib import Path
import pooch
import pytest
import mne.datasets._fsaverage.base
from mne import datasets, read_labels_from_annot, write_labels_to_annot
from mne.datasets import fetch_dataset, fetch_infant_template, fetch_phantom, testing
from mne.datasets._fsaverage.base import _set_montage_coreg_path
from mne.datasets._infant import base as infant_base
from mne.datasets._phantom import base as phantom_base
from mne.datasets.utils import _manifest_check_download
from mne.utils import (
ArgvSetter,
_pl,
catch_logging,
get_subjects_dir,
hashfunc,
requires_good_network,
use_log_level,
)
subjects_dir = testing.data_path(download=False) / "subjects"
def test_datasets_basic(tmp_path, monkeypatch):
"""Test simple dataset functions."""
# XXX 'hf_sef' and 'misc' do not conform to these standards
for dname in (
"sample",
"somato",
"spm_face",
"testing",
"opm",
"bst_raw",
"bst_auditory",
"bst_resting",
"multimodal",
"bst_phantom_ctf",
"bst_phantom_elekta",
"kiloword",
"mtrf",
"phantom_4dbti",
"visual_92_categories",
"fieldtrip_cmc",
):
if dname.startswith("bst"):
dataset = getattr(datasets.brainstorm, dname)
else:
dataset = getattr(datasets, dname)
if str(dataset.data_path(download=False)) != ".":
assert isinstance(dataset.get_version(), str)
assert datasets.has_dataset(dname)
else:
assert dataset.get_version() is None
assert not datasets.has_dataset(dname)
print(f"{dname}: {datasets.has_dataset(dname)}")
# Explicitly test one that isn't preset (given the config)
monkeypatch.setenv("MNE_DATASETS_SAMPLE_PATH", str(tmp_path))
dataset = datasets.sample
assert str(dataset.data_path(download=False)) == "."
assert dataset.get_version() != ""
assert dataset.get_version() is None
# don't let it read from the config file to get the directory,
# force it to look for the default
monkeypatch.setenv("_MNE_FAKE_HOME_DIR", str(tmp_path))
monkeypatch.delenv("SUBJECTS_DIR", raising=False)
assert datasets.utils._get_path(None, "foo", "bar") == tmp_path / "mne_data"
assert get_subjects_dir(None) is None
_set_montage_coreg_path()
sd = get_subjects_dir()
assert sd.name.endswith("MNE-fsaverage-data")
monkeypatch.setenv("MNE_DATA", str(tmp_path / "foo"))
with pytest.raises(FileNotFoundError, match="as specified by MNE_DAT"):
testing.data_path(download=False)
def noop(*args, **kwargs):
return
monkeypatch.setattr(mne.datasets._fsaverage.base, "_manifest_check_download", noop)
sd_2 = datasets.fetch_fsaverage()
assert sd / "fsaverage" == sd_2
@requires_good_network
def test_downloads(tmp_path, monkeypatch, capsys):
"""Test dataset URL and version handling."""
# Try actually downloading a dataset
kwargs = dict(path=tmp_path, verbose=True)
# XXX we shouldn't need to disable capsys here, but there's a pytest bug
# that we're hitting (https://github.com/pytest-dev/pytest/issues/5997)
# now that we use pooch
with capsys.disabled():
with pytest.raises(RuntimeError, match="Do not download .* in tests"):
path = datasets._fake.data_path(update_path=False, **kwargs)
monkeypatch.setattr(
datasets.utils, "_MODULES_TO_ENSURE_DOWNLOAD_IS_FALSE_IN_TESTS", ()
)
path = datasets._fake.data_path(update_path=False, **kwargs)
assert path.is_dir()
assert (path / "bar").is_file()
assert not datasets.has_dataset("fake") # not in the desired path
assert datasets._fake.get_version() is None
assert datasets.utils._get_version("fake") is None
monkeypatch.setenv("_MNE_FAKE_HOME_DIR", str(tmp_path))
with pytest.warns(RuntimeWarning, match="non-standard config"):
new_path = datasets._fake.data_path(update_path=True, **kwargs)
assert path == new_path
out, _ = capsys.readouterr()
assert "Downloading" not in out
# No version: shown as existing but unknown version
assert datasets.has_dataset("fake")
# XXX logic bug, should be "unknown"
assert datasets._fake.get_version() == "0.0"
# With a version but no required one: shown as existing and gives version
fname = tmp_path / "foo" / "version.txt"
with open(fname, "w") as fid:
fid.write("0.1")
assert datasets.has_dataset("fake")
assert datasets._fake.get_version() == "0.1"
datasets._fake.data_path(download=False, **kwargs)
out, _ = capsys.readouterr()
assert "out of date" not in out
# With the required version: shown as existing with the required version
monkeypatch.setattr(datasets._fetch, "_FAKE_VERSION", "0.1")
assert datasets.has_dataset("fake")
assert datasets._fake.get_version() == "0.1"
datasets._fake.data_path(download=False, **kwargs)
out, _ = capsys.readouterr()
assert "out of date" not in out
monkeypatch.setattr(datasets._fetch, "_FAKE_VERSION", "0.2")
# With an older version:
# 1. Marked as not actually being present
assert not datasets.has_dataset("fake")
# 2. Will try to update when `data_path` gets called, with logged message
want_msg = "Correctly trying to download newer version"
def _error_download(self, fname, downloader, processor):
url = self.get_url(fname)
full_path = self.abspath / fname
assert "foo.tgz" in url
assert str(tmp_path) in str(full_path)
raise RuntimeError(want_msg)
monkeypatch.setattr(pooch.Pooch, "fetch", _error_download)
with pytest.raises(RuntimeError, match=want_msg):
datasets._fake.data_path(**kwargs)
out, _ = capsys.readouterr()
assert re.match(r".* 0\.1 .*out of date.* 0\.2.*", out, re.MULTILINE), out
# Hash mismatch suggestion
# https://mne.discourse.group/t/fsaverage-hash-value-mismatch/4663/3
want_msg = "MD5 hash of downloaded file (MNE-sample-data-processed.tar.gz) does not match the known hash: expected md5:e8f30c4516abdc12a0c08e6bae57409c but got a9dfc7e8843fd7f8a928901e12fb3d25. Deleted download for safety. The downloaded file may have been corrupted or the known hash may be outdated." # noqa: E501
def _error_download_2(self, fname, downloader, processor):
url = self.get_url(fname)
full_path = self.abspath / fname
assert "foo.tgz" in url
assert str(tmp_path) in str(full_path)
raise ValueError(want_msg)
shutil.rmtree(tmp_path / "foo")
monkeypatch.setattr(pooch.Pooch, "fetch", _error_download_2)
with pytest.raises(ValueError, match=".*known hash.*force_update=True.*"):
datasets._fake.data_path(download=True, force_update=True, **kwargs)
@pytest.mark.slowtest
@testing.requires_testing_data
@requires_good_network
def test_fetch_parcellations(tmp_path):
"""Test fetching parcellations."""
pytest.importorskip("nibabel")
this_subjects_dir = tmp_path
fsaverage_dir = this_subjects_dir / "fsaverage"
(fsaverage_dir / "label").mkdir(parents=True)
(fsaverage_dir / "surf").mkdir()
for hemi in ("lh", "rh"):
shutil.copyfile(
subjects_dir / "fsaverage" / "surf" / f"{hemi}.white",
fsaverage_dir / "surf" / f"{hemi}.white",
)
# speed up by prenteding we have one of them
with open(fsaverage_dir / "label" / "lh.aparc_sub.annot", "wb"):
pass
datasets.fetch_aparc_sub_parcellation(subjects_dir=this_subjects_dir)
with ArgvSetter(("--accept-hcpmmp-license",)):
datasets.fetch_hcp_mmp_parcellation(subjects_dir=this_subjects_dir)
for hemi in ("lh", "rh"):
assert (fsaverage_dir / "label" / f"{hemi}.aparc_sub.annot").is_file()
# test our annot round-trips here
kwargs = dict(
subject="fsaverage", hemi="both", sort=False, subjects_dir=this_subjects_dir
)
labels = read_labels_from_annot(parc="HCPMMP1", **kwargs)
write_labels_to_annot(
labels,
parc="HCPMMP1_round",
table_name="./left.fsaverage164.label.gii",
**kwargs,
)
orig = fsaverage_dir / "label" / "lh.HCPMMP1.annot"
first = hashfunc(orig)
new = str(orig)[:-6] + "_round.annot"
second = hashfunc(new)
assert first == second
_zip_fnames = ["foo/foo.txt", "foo/bar.txt", "foo/baz.txt"]
def _fake_zip_fetch(url, path, fname, *args, **kwargs):
path = Path(path)
assert isinstance(fname, str)
fname = path / fname
with zipfile.ZipFile(fname, "w") as zipf:
with zipf.open("foo/", "w"):
pass
for fname in _zip_fnames:
with zipf.open(fname, "w"):
pass
@pytest.mark.parametrize("n_have", range(len(_zip_fnames)))
def test_manifest_check_download(tmp_path, n_have, monkeypatch):
"""Test our manifest downloader."""
monkeypatch.setattr(pooch, "retrieve", _fake_zip_fetch)
destination = tmp_path / "empty"
manifest_path = tmp_path / "manifest.txt"
with open(manifest_path, "w") as fid:
for fname in _zip_fnames:
fid.write(f"{fname}\n")
assert n_have in range(len(_zip_fnames) + 1)
assert not destination.is_file()
if n_have > 0:
(destination / "foo").mkdir(parents=True)
assert (destination / "foo").is_dir()
for fname in _zip_fnames:
assert not (destination / fname).is_file()
for fname in _zip_fnames[:n_have]:
with open(destination / fname, "w"):
pass
with catch_logging() as log:
with use_log_level(True):
# we mock the pooch.retrieve so these are not used
url = hash_ = ""
_manifest_check_download(manifest_path, destination, url, hash_)
log = log.getvalue()
n_missing = 3 - n_have
assert (f"{n_missing} file{_pl(n_missing)} missing from") in log
for want in ("Extracting missing", "Successfully "):
if n_missing > 0:
assert want in log
else:
assert want not in log
assert (destination).is_dir()
for fname in _zip_fnames:
assert (destination / fname).is_file()
def _fake_mcd(manifest_path, destination, url, hash_, name=None, fake_files=False):
if name is None:
name = url.split("/")[-1].split(".")[0]
assert name in url
assert name in str(destination)
assert name in str(manifest_path)
assert len(hash_) == 32
if fake_files:
with open(manifest_path) as fid:
for path in fid:
path = path.strip()
if not path:
continue
fname = destination / path
fname.parent.mkdir(exist_ok=True)
with open(fname, "wb"):
pass
def test_infant(tmp_path, monkeypatch):
"""Test fetch_infant_template."""
monkeypatch.setattr(infant_base, "_manifest_check_download", _fake_mcd)
fetch_infant_template("12mo", subjects_dir=tmp_path)
with pytest.raises(ValueError, match="Invalid value for"):
fetch_infant_template("0mo", subjects_dir=tmp_path)
def test_phantom(tmp_path, monkeypatch):
"""Test phantom data downloading."""
# The Otaniemi file is only ~6MB, so in principle maybe we could test
# an actual download here. But it doesn't seem worth it given that
# CircleCI will at least test the VectorView one, and this file should
# not change often.
monkeypatch.setattr(
phantom_base,
"_manifest_check_download",
partial(_fake_mcd, name="phantom_otaniemi", fake_files=True),
)
fetch_phantom("otaniemi", subjects_dir=tmp_path)
assert (tmp_path / "phantom_otaniemi" / "mri" / "T1.mgz").is_file()
@requires_good_network
def test_fetch_uncompressed_file(tmp_path):
"""Test downloading an uncompressed file with our fetch function."""
dataset_dict = dict(
dataset_name="license",
url="https://raw.githubusercontent.com/mne-tools/mne-python/main/LICENSE.txt",
archive_name="LICENSE.foo",
folder_name=tmp_path / "foo",
hash=None,
)
fetch_dataset(dataset_dict, path=None, force_update=True)
assert (tmp_path / "foo" / "LICENSE.foo").is_file()